Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions effectful/handlers/jax/monoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import jax

import effectful.handlers.jax.numpy as jnp
from effectful.handlers.jax import bind_dims, unbind_dims
from effectful.handlers.jax.scipy.special import logsumexp
from effectful.ops.monoid import (
CommutativeMonoid,
CommutativeMonoidWithZero,
Monoid,
Semilattice,
Streams,
distributes_over,
outer_stream,
)
from effectful.ops.semantics import evaluate, handler, typeof
from effectful.ops.syntax import deffn
from effectful.ops.types import Operation


@Operation.define
def cartesian_prod(x, y):
    """Row-wise cartesian product of two arrays.

    1-D inputs are treated as column vectors. For ``x`` with n rows and
    ``y`` with m rows, the result has n*m rows; row ``i*m + j`` is the
    horizontal concatenation of ``x[i]`` and ``y[j]``.
    """
    left = x if x.ndim != 1 else x[:, None]
    right = y if y.ndim != 1 else y[:, None]
    # Repeat each left row once per right row, and cycle the right rows
    # once per left row, so zipping them enumerates every pair.
    left_rep = jnp.repeat(left, right.shape[0], axis=0)
    right_rep = jnp.tile(right, (left.shape[0], 1))
    return jnp.hstack([left_rep, right_rep])


# Standard reduction monoids over jax arrays: each pairs an elementwise
# kernel with its identity element (and, for Product, an annihilating zero).
Sum = CommutativeMonoid(kernel=jnp.add, identity=jnp.asarray(0))
Product = CommutativeMonoidWithZero(
    kernel=jnp.multiply, identity=jnp.asarray(1), zero=jnp.asarray(0)
)
# Min/Max are idempotent, hence modeled as semilattices; their identities
# are the respective infinities.
Min = Semilattice(kernel=jnp.minimum, identity=jnp.asarray(float("-inf")))
Max = Semilattice(kernel=jnp.maximum, identity=jnp.asarray(float("inf")))
# logaddexp is addition in log-space; -inf (= log 0) is its identity.
LogSumExp = CommutativeMonoid(kernel=jnp.logaddexp, identity=jnp.asarray(float("-inf")))
# Row-wise cartesian product with the empty array as identity.
# NOTE(review): cartesian_prod applied to this empty identity yields an
# empty result rather than the other operand — confirm the identity law
# is intended to hold only symbolically.
CartesianProd = Monoid(kernel=cartesian_prod, identity=jnp.array([]))

# Declare distributive laws between the monoid kernels (e.g. the usual
# (+, *) semiring via Product/Sum, min-max lattice laws, and ordinary
# addition distributing into log-space sums).
distributes_over.register(Max.plus, Min.plus)
distributes_over.register(Min.plus, Max.plus)
distributes_over.register(Sum.plus, Min.plus)
distributes_over.register(Sum.plus, Max.plus)
distributes_over.register(Product.plus, Sum.plus)
distributes_over.register(Sum.plus, LogSumExp.plus)

# Maps each monoid's combining operation to the jax reduction that folds
# a whole array axis at once, instead of folding element by element.
ARRAY_REDUCE = {
    Sum.plus: jnp.sum,
    Product.plus: jnp.prod,
    Min.plus: jnp.min,
    Max.plus: jnp.max,
    LogSumExp.plus: logsumexp,
}


@Monoid.reduce.register(jax.Array)
def _reduce_array(self, body: jax.Array, streams: Streams):
    """Reduce an array-typed body by eliminating one array stream at a time.

    Picks a stream that may be outermost in the loop nest and whose body is
    a jax array, substitutes it into ``body`` indexed by a fresh variable,
    recursively reduces the remaining streams, then collapses the fresh
    axis with the vectorized jax reduction matching this monoid's ``plus``
    (see ARRAY_REDUCE). Falls back to generic object reduction when no
    array-typed stream is available.
    """
    # Vectorized reduction (e.g. jnp.sum) corresponding to self.plus.
    reductor = ARRAY_REDUCE[self.plus]
    # Fresh free variable naming the stream axis being eliminated.
    index = Operation.define(jax.Array)

    if not streams:
        # Nothing to eliminate: the empty reduction is the identity.
        return self.identity

    # find and reduce an array stream
    for stream_key, stream_body, streams_tail in outer_stream(streams):
        if typeof(stream_body) != jax.Array:
            continue

        # Replace the stream variable by its body unbound along the fresh
        # index, then evaluate the body and the remaining streams under
        # that substitution.
        with handler({stream_key: deffn(unbind_dims(stream_body, index))}):
            (eval_body, eval_streams_tail) = evaluate(body), evaluate(streams_tail)
            assert isinstance(eval_streams_tail, dict)

        # Recurse on any remaining streams before collapsing this axis.
        reduce_tail = (
            self.reduce(eval_body, eval_streams_tail)
            if len(eval_streams_tail) > 0
            else eval_body
        )
        # Re-bind the fresh index as the leading axis and fold it away.
        return reductor(bind_dims(reduce_tail, index), axis=0)

    # No array stream could be made outermost: use the generic fallback.
    return self._reduce_object(body, streams)
119 changes: 69 additions & 50 deletions effectful/ops/monoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,32 @@
import numbers
import typing
from collections import Counter, defaultdict
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass
from graphlib import TopologicalSorter
from typing import Annotated, Any

from effectful.internals.disjoint_set import DisjointSet
from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler
from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof
from effectful.ops.syntax import (
ObjectInterpretation,
Scoped,
_NumberTerm,
defdata,
deffn,
implements,
iter_,
syntactic_eq,
syntactic_hash,
)
from effectful.ops.types import Interpretation, NotHandled, Operation, Term
from effectful.ops.types import (
Expr,
Interpretation,
NotHandled,
Operation,
Term,
_CustomSingleDispatchCallable,
)

# Note: The streams value type should be something like Iterable[T], but some of
# our target stream types (e.g. jax.Array) are not subtypes of Iterable
Expand All @@ -36,17 +44,21 @@
)


def outer_stream(
    streams: Streams,
) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]:
    """Returns the streams that can be ordered outermost in the loop nest as
    well as the remaining streams in the nest.

    Yields ``(op, expr, rest)`` triples where ``op`` is a stream variable
    whose body mentions no other stream variable, ``expr`` is its body,
    and ``rest`` maps every other stream to its body.
    """
    stream_vars = set(streams.keys())
    # Predecessors: the other stream variables each stream's body depends on.
    pred = {k: fvsof(v) & stream_vars for k, v in streams.items()}
    topo = TopologicalSorter(pred)
    topo.prepare()
    # Only the first "ready" group is produced — exactly the streams with
    # no unresolved dependencies, hence safe to place outermost.
    return (
        (op, streams[op], {k: v for (k, v) in streams.items() if k != op})
        for op in topo.get_ready()
    )


class Monoid[T]:
Expand Down Expand Up @@ -114,58 +126,65 @@ def _(self, *args):
return result

    @Operation.define
    @_CustomSingleDispatchCallable  # type: ignore[arg-type]
    def reduce[A, B, U: Body](
        dispatch,
        self,
        /,
        body: Annotated[U, Scoped[A | B]],
        streams: Annotated[Streams, Scoped[A]],
    ) -> Annotated[U, Scoped[B]]:
        """Reduce ``body`` over ``streams`` under this monoid.

        Dispatches on the semantic type of ``body`` (via ``typeof``) so
        that mappings, sequences, callables, generators, and backend array
        types each get their own reduction strategy; ``dispatch`` is
        injected by ``_CustomSingleDispatchCallable``.
        """
        return dispatch(typeof(body))(self, body, streams)  # type: ignore[operator]

    @reduce.register(object)  # type: ignore[attr-defined]
    def _reduce_object(self, body: object, streams: Streams):
        """Fallback reduction: iterate one concrete stream and fold with plus.

        Finds a stream that can be outermost and whose body is a concrete
        (non-Term) iterable, evaluates ``body`` and the remaining streams
        under each element binding, recursively reduces, and combines the
        results with ``self.plus``. If every stream is symbolic, the
        reduction itself stays symbolic.
        """
        if not streams:
            # Nothing to eliminate: the empty reduction is the identity.
            return self.identity

        # find and reduce a ground stream
        for stream_key, stream_body, streams_tail in outer_stream(streams):
            if isinstance(stream_body, Term):
                # Symbolic stream body: cannot iterate it; try another.
                continue

            stream_values_iter = iter(stream_body)

            # if we iterate and get a term instead of a real iterator, skip
            if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_:
                continue

            new_reduces = []
            for stream_val in stream_values_iter:
                # Bind the stream variable to the current element while
                # evaluating the body and the remaining streams.
                with handler({stream_key: deffn(stream_val)}):
                    eval_args = evaluate((body, streams_tail))
                    assert isinstance(eval_args, tuple)
                    new_reduces.append(self.reduce(*eval_args))

            return self.plus(*new_reduces)

        # Every candidate stream was symbolic: build a symbolic reduce term.
        return defdata(self.reduce, body, streams)

@reduce.register(Callable) # type: ignore[attr-defined]
def _reduce_callable(self, body: Callable, streams):
if isinstance(body, Term):
return defdata(self.reduce, body, streams)
return lambda *a, **k: self.reduce(body(*a, **k), streams)

@reduce.register(Mapping) # type: ignore[attr-defined]
def _reduce_mapping(self, body: Mapping, streams):
if isinstance(body, Term):
return defdata(self.reduce, body, streams)
return {k: self.reduce(v, streams) for (k, v) in body.items()}

@reduce.register # type: ignore[attr-defined]
def _(self, body: Sequence, streams):
@reduce.register(list | tuple) # type: ignore[attr-defined]
def _reduce_sequence(self, body: Sequence, streams):
if isinstance(body, Term):
return defdata(self.reduce, body, streams)
return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg]

@reduce.register # type: ignore[attr-defined]
def _(self, body: Generator, streams):
@reduce.register(Generator) # type: ignore[attr-defined]
def _reduce_generator(self, body: Generator, streams):
if isinstance(body, Term):
return defdata(self.reduce, body, streams)
return (self.reduce(x, streams) for x in body)


Expand Down
Loading
Loading