Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion effectful/internals/product_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def map_structure(func, expr):
else:
return type(expr)(map_structure(func, tuple(expr.items())))
elif isinstance(expr, collections.abc.Sequence):
if isinstance(expr, str | bytes):
if isinstance(expr, str | bytes | range):
return expr
elif (
isinstance(expr, tuple)
Expand Down
159 changes: 152 additions & 7 deletions effectful/ops/monoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,26 @@
import numbers
import typing
from collections import Counter, defaultdict
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
from dataclasses import dataclass
from graphlib import TopologicalSorter
from typing import Annotated, Any

from effectful.internals.disjoint_set import DisjointSet
from effectful.internals.runtime import interpreter
from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler
from effectful.ops.syntax import (
ObjectInterpretation,
Scoped,
_NumberTerm,
defdata,
deffn,
implements,
iter_,
syntactic_eq,
syntactic_hash,
)
from effectful.ops.types import Interpretation, NotHandled, Operation, Term
from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term

# Note: The streams value type should be something like Iterable[T], but some of
# our target stream types (e.g. jax.Array) are not subtypes of Iterable
Expand Down Expand Up @@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S:
def _plus[S](self, *args: S) -> S:
return typing.cast(S, functools.reduce(self.kernel, args, self.identity))

@_plus.register(Sequence)
@_plus.register(tuple)
def _(self, *args):
return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True))
return tuple(self.plus(*vs) for vs in zip(*args, strict=True))

@_plus.register(Generator)
def _(self, *args):
return (self.plus(*vs) for vs in zip(*args, strict=True))

@_plus.register(Mapping)
def _(self, *args):
Expand Down Expand Up @@ -161,8 +167,8 @@ def _(self, body: Mapping, streams):
return {k: self.reduce(v, streams) for (k, v) in body.items()}

@reduce.register # type: ignore[attr-defined]
def _(self, body: Sequence, streams):
return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg]
def _(self, body: tuple, streams):
return tuple(self.reduce(x, streams) for x in body)

@reduce.register # type: ignore[attr-defined]
def _(self, body: Generator, streams):
Expand Down Expand Up @@ -252,12 +258,26 @@ def _arg_max[T](
return b if b[0] > a[0] else a # type: ignore


@Operation.define
def product[T](
a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T]
) -> Iterable[tuple[T, ...]]:
if isinstance(a, Term) or isinstance(b, Term):
raise NotHandled

def to_tuple(x):
return x if isinstance(x, tuple) else (x,)

return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)]


# Standard reduction monoids over numbers; `identity` is the unit of `kernel`.
Min = Semilattice(kernel=min, identity=float("inf"))
Max = Semilattice(kernel=max, identity=float("-inf"))
# Arg-reductions operate on (key, payload) pairs ordered by their first element.
ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None))
ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None))
Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0)
# Multiplication additionally has an absorbing element (zero) besides its unit.
Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0)
# Cartesian product of streams; the identity is the singleton empty tuple.
CartesianProduct = Monoid(kernel=product, identity=[()])


@dataclass
Expand Down Expand Up @@ -545,11 +565,136 @@ def reduce(self, monoid, body, streams):
return fwd()


def inner_stream(
    streams: dict[Operation, Expr],
) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]:
    """Yield the streams that can be ordered innermost in the loop nest.

    A stream operation can be innermost only if no *other* stream's
    expression mentions it (it has no dependents in the nest).  For each
    such operation ``op``, yields a triple of

    - the remaining streams (``streams`` without ``op``),
    - ``op`` itself, and
    - ``op``'s stream expression.
    """
    stream_vars = set(streams.keys())

    # dependents[pred] = streams whose defining expressions mention pred;
    # pred must therefore be bound outside all of them in the loop nest.
    dependents: defaultdict[Operation, set] = defaultdict(set)
    for k, v in streams.items():
        for pred in fvsof(v) & stream_vars:
            dependents[pred].add(k)

    # BUG FIX: the original attached an `else:` to the inner `for` loop.
    # A for/else's else-branch runs whenever the loop finishes without
    # `break` -- i.e. always here -- so *every* key was classified as
    # having no dependents and the topological sort was dead code.  The
    # correct innermost candidates are the streams nothing else refers to.
    innermost = stream_vars - set(dependents)
    return (
        ({k: v for (k, v) in streams.items() if k != op}, op, streams[op])
        for op in innermost
    )


def match_reduce(term: Term) -> tuple | None:
    """Return the argument tuple of a ``reduce`` term, or ``None``.

    Re-applies ``term``'s operation under an interpretation that intercepts
    ``Monoid.reduce`` and merely records the positional arguments it was
    invoked with.  If ``term`` is a reduce at its head, the recorded
    ``(monoid, body, streams)`` tuple is returned; otherwise ``reduce_args``
    is never set and ``None`` is returned.

    NOTE(review): this is an indirect way to pattern-match a reduce.  It is
    used because the subclassing relation between operations is implicit and
    all (present and future) reduce variants should match -- see the PR
    discussion; confirm re-application has no other side effects.
    """
    reduce_args = None

    def set_reduce_args(*args, **kwargs):
        # Capture the arguments instead of performing the reduction.
        nonlocal reduce_args
        reduce_args = args

    with interpreter({Monoid.reduce: set_reduce_args}):
        term.op(*term.args, **term.kwargs)
    return reduce_args


class ReduceDistributeCartesianProduct(ObjectInterpretation):
    """Eliminates a reduce over a cartesian product.

    ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ)

    This transform is also called inversion in the lifting
    literature (e.g. [1]).

    More specifically, this transform implements the identity

        reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2)
            = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: body1}), S1), S2)

    where × is the cartesian product and ⨂ distributes over ⨁.

    Note: This could be generalized to grouped inversion [2].

    [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order
    probabilistic inference." IJCAI. 2005.
    [2] Taghipour, Nima, et al. "Completeness results for lifted
    variable elimination." AISTATS. 2013.
    """

    @implements(CommutativeMonoid.reduce)
    def reduce(self, sum_monoid: Monoid, sum_body, sum_streams):
        # Only symbolic bodies can contain a product-of-reduces pattern.
        if not (isinstance(sum_body, Term)):
            return fwd()

        # body is a product or multiplication of products
        if distributes_over(sum_body.op, sum_monoid.plus):
            prod_reduces = sum_body.args
        else:
            prod_reduces = [sum_body]

        # Collect (monoid, abstracted body, stream op, stream) for each
        # factor; bail out (fwd) unless every factor is a single-stream
        # reduce over the same monoid that distributes over the sum.
        products: list[tuple[Monoid, Callable, Operation, Term]] = []
        for prod_reduce in prod_reduces:
            prod_args = match_reduce(prod_reduce)
            if prod_args is None:
                return fwd()
            (prod_monoid, prod_body, prod_streams) = prod_args
            if not (
                distributes_over(prod_monoid.plus, sum_monoid.plus)
                and (len(products) == 0 or products[-1][0] == prod_monoid)
            ):
                return fwd()

            if len(prod_streams) > 1 or len(prod_streams) == 0:
                return fwd()
            (prod_op, prod_stream) = next(iter(prod_streams.items()))
            products.append(
                (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream)
            )

        assert len(products) > 0

        # Look for an innermost sum stream that is a cartesian product
        # whose components feed exactly the product factors found above.
        for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams):
            if not (
                isinstance(cprod_term, Term)
                and cprod_term.op == CartesianProduct.reduce
            ):
                continue
            (cprod_body, cprod_streams) = cprod_term.args

            if not all(
                prod_stream.op == cprod_op for (_, _, _, prod_stream) in products
            ):
                continue

            # Fresh operation standing for one element of the cartesian
            # product's body; it replaces the per-factor stream variables.
            prod_op = Operation.define(products[0][2])
            # NOTE(review): `Product.plus` is the global numeric product
            # here rather than `prod_monoid.plus` -- confirm this is
            # intentional for non-numeric product monoids.
            inner_sum = sum_monoid.reduce(
                Product.plus(
                    *(prod_body(prod_op()) for (_, prod_body, _, _) in products)
                ),
                {prod_op: cprod_body},
            )
            prod = prod_monoid.reduce(inner_sum, cprod_streams)
            outer_sum = (
                sum_monoid.reduce(prod, outer_sum_streams)
                if outer_sum_streams
                else prod
            )
            return outer_sum

        return fwd()


# Normalization interpretation for reduce terms: the coproduct of all
# rewrite rules, tried in order.  (Reconstructed post-change form; the
# original span interleaved the deleted pre-change list with the new one.)
NormalizeReduceIntp = functools.reduce(
    coproduct,
    typing.cast(
        list[Interpretation],
        [
            ReduceNoStreams(),
            ReduceFusion(),
            ReduceSplit(),
            ReduceFactorization(),
            ReduceDistributeCartesianProduct(),
        ],
    ),
)

Expand Down
1 change: 1 addition & 0 deletions effectful/ops/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def evaluate[T](
@evaluate.register(object)
@evaluate.register(str)
@evaluate.register(bytes)
@evaluate.register(range)
def _evaluate_object[T](expr: T, **kwargs) -> T:
if dataclasses.is_dataclass(expr) and not isinstance(expr, type):
return typing.cast(
Expand Down
5 changes: 4 additions & 1 deletion effectful/ops/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ def _instance_op(instance, *args, **kwargs):
else:
return default_result

instance_op = self.define(types.MethodType(_instance_op, instance))
name = ("" if owner is None else f"{owner.__name__}_") + self.__name__
Comment thread
eb8680 marked this conversation as resolved.
instance_op = self.define(
types.MethodType(_instance_op, instance), name=name
)
instance.__dict__[self._name_on_instance] = instance_op
return instance_op
elif instance is not None:
Expand Down
14 changes: 7 additions & 7 deletions tests/_monoid_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]:
if annotation is float:
return st.floats(allow_nan=False)
if get_origin(annotation) is list and get_args(annotation) == (int,):
return st.lists(st.integers())
return st.lists(st.integers(), max_size=3)
raise NotImplementedError(
f"No value strategy for return annotation {annotation!r}; "
"supported: int, list[int]"
)


_UNARY_INT_FNS: list[Callable[[int], int]] = [
_UNARY_NUM_FNS: list[Callable[[int], int]] = [
lambda x: x,
lambda x: x + 1,
lambda x: x - 1,
Expand All @@ -30,7 +30,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]:
lambda x: 3 * x + 1,
]

_BINARY_INT_FNS: list[Callable[[int, int], int]] = [
_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [
lambda x, y: x + y,
lambda x, y: x - y,
lambda x, y: x * y,
Expand Down Expand Up @@ -58,10 +58,10 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]:

if not params:
return _value_strategy_for(ret).map(deffn)
if ret is int and param_types == (int,):
return st.sampled_from(_UNARY_INT_FNS)
if ret is int and param_types == (int, int):
return st.sampled_from(_BINARY_INT_FNS)
if ret in (int, float) and param_types == (int,):
return st.sampled_from(_UNARY_NUM_FNS)
if ret in (int, float) and param_types == (int, int):
return st.sampled_from(_BINARY_NUM_FNS)
if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,):
return st.sampled_from(_UNARY_LIST_FNS)
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_handlers_llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_agent_tool_names_are_valid_integration():
agent = _ToolNameAgent()
template = agent.ask
tools = template.tools
expected_helper_tool_name = f"self__{agent.helper.__name__}"
expected_helper_tool_name = "self__helper"
Comment thread
eb8680 marked this conversation as resolved.
assert tools
assert expected_helper_tool_name in tools
assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools)
Expand Down
Loading
Loading