-
Notifications
You must be signed in to change notification settings - Fork 4
Add inversion from weighted
#655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: staging-weighted
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,24 +4,26 @@ | |
| import numbers | ||
| import typing | ||
| from collections import Counter, defaultdict | ||
| from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence | ||
| from collections.abc import Callable, Generator, Iterable, Iterator, Mapping | ||
| from dataclasses import dataclass | ||
| from graphlib import TopologicalSorter | ||
| from typing import Annotated, Any | ||
|
|
||
| from effectful.internals.disjoint_set import DisjointSet | ||
| from effectful.internals.runtime import interpreter | ||
| from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler | ||
| from effectful.ops.syntax import ( | ||
| ObjectInterpretation, | ||
| Scoped, | ||
| _NumberTerm, | ||
| defdata, | ||
| deffn, | ||
| implements, | ||
| iter_, | ||
| syntactic_eq, | ||
| syntactic_hash, | ||
| ) | ||
| from effectful.ops.types import Interpretation, NotHandled, Operation, Term | ||
| from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term | ||
|
|
||
| # Note: The streams value type should be something like Iterable[T], but some of | ||
| # our target stream types (e.g. jax.Array) are not subtypes of Iterable | ||
|
|
@@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S: | |
| def _plus[S](self, *args: S) -> S: | ||
| return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) | ||
|
|
||
| @_plus.register(Sequence) | ||
| @_plus.register(tuple) | ||
| def _(self, *args): | ||
| return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) | ||
| return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) | ||
|
|
||
| @_plus.register(Generator) | ||
| def _(self, *args): | ||
| return (self.plus(*vs) for vs in zip(*args, strict=True)) | ||
|
|
||
| @_plus.register(Mapping) | ||
| def _(self, *args): | ||
|
|
@@ -161,8 +167,8 @@ def _(self, body: Mapping, streams): | |
| return {k: self.reduce(v, streams) for (k, v) in body.items()} | ||
|
|
||
| @reduce.register # type: ignore[attr-defined] | ||
| def _(self, body: Sequence, streams): | ||
| return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] | ||
| def _(self, body: tuple, streams): | ||
| return tuple(self.reduce(x, streams) for x in body) | ||
|
|
||
| @reduce.register # type: ignore[attr-defined] | ||
| def _(self, body: Generator, streams): | ||
|
|
@@ -252,12 +258,26 @@ def _arg_max[T]( | |
| return b if b[0] > a[0] else a # type: ignore | ||
|
|
||
|
|
||
| @Operation.define | ||
| def product[T]( | ||
| a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] | ||
| ) -> Iterable[tuple[T, ...]]: | ||
| if isinstance(a, Term) or isinstance(b, Term): | ||
| raise NotHandled | ||
|
|
||
| def to_tuple(x): | ||
| return x if isinstance(x, tuple) else (x,) | ||
|
|
||
| return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] | ||
|
|
||
|
|
||
# Module-level reduction structures. Each one pairs a binary kernel with the
# identity element passed alongside it.
# Min / Max: builtin min / max, with +inf / -inf as the respective identities.
| Min = Semilattice(kernel=min, identity=float("inf")) | ||
| Max = Semilattice(kernel=max, identity=float("-inf")) | ||
# ArgMin / ArgMax: the identity is an (inf-like score, None) pair; _arg_max
# keeps whichever operand has the larger element 0 (ties keep the first).
# NOTE(review): _arg_min is presumably the mirror image - confirm.
| ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) | ||
| ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) | ||
# Sum / Product: built on _NumberTerm.__add__ / _NumberTerm.__mul__; Product
# additionally declares 0 as its zero element.
| Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) | ||
| Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) | ||
# CartesianProduct: kernel is the `product` operation; the identity is the
# one-element list holding the empty tuple, so combining it with a stream
# just yields that stream's elements as tuples.
| CartesianProduct = Monoid(kernel=product, identity=[()]) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -545,11 +565,136 @@ def reduce(self, monoid, body, streams): | |
| return fwd() | ||
|
|
||
|
|
||
def inner_stream(
    streams: dict[Operation, Expr],
) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]:
    """Enumerate the streams that can be ordered innermost in the loop nest.

    A stream can be placed innermost only when no *other* stream's
    expression mentions its variable; otherwise the dependent stream would
    have to be nested inside it.

    Args:
        streams: Mapping from stream variable to the expression that
            produces the stream.

    Returns:
        A lazy iterable of ``(remaining, op, stream)`` triples, one per
        admissible innermost stream, where ``remaining`` is ``streams``
        without ``op`` and ``stream`` is ``streams[op]``. Candidates are
        produced in the insertion order of ``streams``.
    """
    stream_vars = set(streams)

    # Stream variables that occur free in some stream expression, i.e. the
    # variables at least one stream depends on.
    #
    # The previous implementation put the "no dependents" bookkeeping in the
    # `else` clause of a `for` loop that never `break`s. That clause runs on
    # every iteration of the enclosing loop, so every stream was reported as
    # innermost-capable, including those other streams depend on.
    has_dependents: set[Operation] = set()
    for stream in streams.values():
        has_dependents |= fvsof(stream) & stream_vars

    return (
        ({k: v for (k, v) in streams.items() if k != op}, op, streams[op])
        for op in streams
        if op not in has_dependents
    )
|
|
||
|
|
||
def match_reduce(term: Term) -> tuple | None:
    """Return the positional arguments of ``term`` if it is a reduce.

    The term's operation is re-applied to the term's own arguments under an
    interpretation in which ``Monoid.reduce`` only records the positional
    arguments it receives. If the recorder fires, those arguments are
    returned; if it never fires, the result is ``None``.

    NOTE(review): whether ``reduce`` on every ``Monoid`` subclass is routed
    to this handler depends on how the operations are defined elsewhere -
    confirm.
    """
    captured: list[tuple] = []

    def record(*args, **kwargs):
        # Keyword arguments are accepted but deliberately not recorded.
        captured.append(args)

    with interpreter({Monoid.reduce: record}):
        term.op(*term.args, **term.kwargs)

    # Keep the most recent capture, matching "last call wins" semantics.
    return captured[-1] if captured else None
|
|
||
|
|
||
class ReduceDistributeCartesianProduct(ObjectInterpretation):
    """Eliminates a reduce over a cartesian product.

        ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ)

    This transform is also called inversion in the lifting
    literature (e.g. [1]).

    More specifically, this transform implements the identity

        reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2)
          = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: body1}), S1), S2)

    where × is the cartesian product and ⨂ distributes over ⨁. On the
    left-hand side ``v`` ranges over the cartesian product and ``vv`` over
    the components of one element ``v()``. On the right-hand side ``v`` is
    gone: ``vv`` ranges directly over ``body1``, whose free variables are
    bound by ``S1``.

    Note: This could be generalized to grouped inversion [2].

    [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order
        probabilistic inference." IJCAI. 2005.
    [2] Taghipour, Nima, et al. "Completeness results for lifted
        variable elimination." AISTATS. 2013.
    """

    @implements(CommutativeMonoid.reduce)
    def reduce(self, sum_monoid: Monoid, sum_body, sum_streams):
        """Rewrite ``reduce(⨁, sum_body, sum_streams)`` by inversion.

        Forwards to the next handler (``fwd()``) whenever the term does not
        have the shape described in the class docstring.
        """
        if not isinstance(sum_body, Term):
            return fwd()

        # The body is either a single ⨂-reduce or a combination of several
        # ⨂-reduces under an operation that distributes over ⨁.
        if distributes_over(sum_body.op, sum_monoid.plus):
            prod_reduces = sum_body.args
        else:
            prod_reduces = [sum_body]

        # One entry per factor: (monoid, body as a function of the bound
        # variable, bound variable, stream the variable ranges over).
        products: list[tuple[Monoid, Callable, Operation, Expr]] = []
        for prod_reduce in prod_reduces:
            # Constant factors (e.g. numbers) are not terms and cannot be
            # matched as reduces.
            if not isinstance(prod_reduce, Term):
                return fwd()

            prod_args = match_reduce(prod_reduce)
            if prod_args is None or len(prod_args) != 3:
                return fwd()
            (prod_monoid, prod_body, prod_streams) = prod_args

            # Every factor must use the same ⨂, and ⨂ must distribute over ⨁.
            if not (
                distributes_over(prod_monoid.plus, sum_monoid.plus)
                and (not products or products[-1][0] == prod_monoid)
            ):
                return fwd()

            # Only a single-stream inner reduce is supported.
            if len(prod_streams) != 1:
                return fwd()
            (prod_op, prod_stream) = next(iter(prod_streams.items()))
            products.append(
                (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream)
            )

        if not products:
            return fwd()
        prod_monoid = products[0][0]

        for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams):
            if not (
                isinstance(cprod_term, Term)
                and cprod_term.op == CartesianProduct.reduce
            ):
                continue
            (cprod_body, cprod_streams) = cprod_term.args

            # Every factor must range over the cartesian-product variable.
            # A concrete (non-Term) stream has no `.op` and never matches.
            if not all(
                isinstance(prod_stream, Term) and prod_stream.op == cprod_op
                for (_, _, _, prod_stream) in products
            ):
                continue

            # Fresh variable shared by all factors once they are merged.
            prod_op = Operation.define(products[0][2])
            # Combine the factors with the matched ⨂ rather than a
            # hard-coded numeric Product, so the rewrite stays valid for any
            # pair of monoids that passed the distributivity check above.
            # NOTE(review): in the multi-factor case this assumes the
            # combining operation `sum_body.op` is `prod_monoid.plus` - confirm.
            inner_sum = sum_monoid.reduce(
                prod_monoid.plus(
                    *(prod_body(prod_op()) for (_, prod_body, _, _) in products)
                ),
                {prod_op: cprod_body},
            )
            prod = prod_monoid.reduce(inner_sum, cprod_streams)
            if outer_sum_streams:
                return sum_monoid.reduce(prod, outer_sum_streams)
            return prod

        return fwd()
|
|
||
|
|
||
| NormalizeReduceIntp = functools.reduce( | ||
| coproduct, | ||
| typing.cast( | ||
| list[Interpretation], | ||
| [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], | ||
| [ | ||
| ReduceNoStreams(), | ||
| ReduceFusion(), | ||
| ReduceSplit(), | ||
| ReduceFactorization(), | ||
| ReduceDistributeCartesianProduct(), | ||
| ], | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a very strange function. Surely there's a more comprehensible way to write this, even if it ends up being more verbose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. Unfortunately, I'm not sure how to write this in a way that doesn't enumerate operations. The subclassing relation between operations is implicit, and we want to match all subclasses (and leave open the possibility that new ones could be created).