-
Notifications
You must be signed in to change notification settings - Fork 7
GEM: Simplify Indexed tensors #131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
de501c1
592061e
03e9ae6
158b0ec
d2a4584
b49eb5c
340ec40
a7fda02
eab0d90
41981ec
7f58028
7077cf9
55bb313
10a8604
b738ec6
830a0e3
5dfd17d
3fe803d
c0a6e02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -293,7 +293,7 @@ def is_equal(self, other): | |
| return False | ||
| if self.shape != other.shape: | ||
| return False | ||
| return tuple(self.array.flat) == tuple(other.array.flat) | ||
| return numpy.array_equal(self.array, other.array) | ||
|
|
||
| def get_hash(self): | ||
| return hash((type(self), self.shape, tuple(self.array.flat))) | ||
|
|
@@ -684,12 +684,46 @@ def __new__(cls, aggregate, multiindex): | |
| if isinstance(aggregate, Zero): | ||
| return Zero(dtype=aggregate.dtype) | ||
|
|
||
| # All indices fixed | ||
| if all(isinstance(i, int) for i in multiindex): | ||
| if isinstance(aggregate, Constant): | ||
| return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) | ||
| elif isinstance(aggregate, ListTensor): | ||
| return aggregate.array[multiindex] | ||
| # Simplify Literal and ListTensor | ||
| if isinstance(aggregate, (Constant, ListTensor)): | ||
| if all(isinstance(i, int) for i in multiindex): | ||
| # All indices fixed | ||
| sub = aggregate.array[multiindex] | ||
| return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub | ||
|
|
||
| elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): | ||
| # Some indices fixed | ||
| slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) | ||
| sub = aggregate.array[slices] | ||
| sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) | ||
| return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) | ||
|
|
||
| # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) | ||
| if isinstance(aggregate, ComponentTensor): | ||
| B, = aggregate.children | ||
| jj = aggregate.multiindex | ||
| ii = multiindex | ||
|
|
||
| if isinstance(B, Indexed): | ||
| C, = B.children | ||
| kk = B.multiindex | ||
| if not isinstance(C, ComponentTensor): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to avoid
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it is because for |
||
| rep = dict(zip(jj, ii)) | ||
| ll = tuple(rep.get(k, k) for k in kk) | ||
| B = Indexed(C, ll) | ||
| jj = tuple(j for j in jj if j not in kk) | ||
| ii = tuple(rep[j] for j in jj) | ||
| if not ii: | ||
| return B | ||
|
|
||
| if isinstance(B, Indexed): | ||
| C, = B.children | ||
| kk = B.multiindex | ||
| ff = C.free_indices | ||
| if all((j in kk) and (j not in ff) for j in jj): | ||
| rep = dict(zip(jj, ii)) | ||
| ll = tuple(rep.get(k, k) for k in kk) | ||
| return Indexed(C, ll) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also causes recursion. |
||
|
|
||
| self = super(Indexed, cls).__new__(cls) | ||
| self.children = (aggregate,) | ||
|
|
@@ -835,6 +869,11 @@ def __new__(cls, expression, multiindex): | |
| if isinstance(expression, Zero): | ||
| return Zero(shape, dtype=expression.dtype) | ||
|
|
||
| # Index folding | ||
| if isinstance(expression, Indexed): | ||
| if multiindex == expression.multiindex: | ||
| return expression.children[0] | ||
|
|
||
| self = super(ComponentTensor, cls).__new__(cls) | ||
| self.children = (expression,) | ||
| self.multiindex = multiindex | ||
|
|
@@ -871,6 +910,11 @@ def __new__(cls, summand, multiindex): | |
| if not multiindex: | ||
| return summand | ||
|
|
||
| # Flatten nested sums | ||
| if isinstance(summand, IndexSum): | ||
| A, = summand.children | ||
| return IndexSum(A, summand.multiindex + multiindex) | ||
|
|
||
| self = super(IndexSum, cls).__new__(cls) | ||
| self.children = (summand,) | ||
| self.multiindex = multiindex | ||
|
|
@@ -891,15 +935,40 @@ def __new__(cls, array): | |
| dtype = Node.inherit_dtype_from_children(tuple(array.flat)) | ||
|
|
||
| # Handle children with shape | ||
| child_shape = array.flat[0].shape | ||
| e0 = array.flat[0] | ||
| child_shape = e0.shape | ||
| assert all(elem.shape == child_shape for elem in array.flat) | ||
|
|
||
| # Simplify [v[multiindex, j] for j in range(n)] -> partial_indexed(v, multiindex) | ||
| if all(isinstance(elem, Indexed) for elem in array.flat): | ||
| tensor = e0.children[0] | ||
| multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral)) | ||
| index_shape = tuple(i.extent for i in multiindex if isinstance(i, Index)) | ||
| if index_shape + array.shape + child_shape == tensor.shape: | ||
| if all(elem.children[0] == tensor for elem in array.flat[1:]): | ||
| if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)): | ||
| return partial_indexed(tensor, multiindex) | ||
|
Comment on lines
+943
to
+950
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we should make it more explicit that we are only handling multiindex of the following pattern: |
||
|
|
||
| # Simplify [v[j, ...] for j in range(n)] -> v | ||
| if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) | ||
| for elem in array.flat): | ||
| tensor = e0.children[0].children[0] | ||
| if array.shape + child_shape == tensor.shape: | ||
| if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]): | ||
| if all(elem.children[0].multiindex == idx + elem.multiindex | ||
| for idx, elem in numpy.ndenumerate(array)): | ||
| return tensor | ||
|
|
||
| # Flatten nested ListTensors | ||
| if all(isinstance(elem, ListTensor) for elem in array.flat): | ||
| return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape)) | ||
|
|
||
| if child_shape: | ||
| # Destroy structure | ||
| direct_array = numpy.empty(array.shape + child_shape, dtype=object) | ||
| for alpha in numpy.ndindex(array.shape): | ||
| for alpha, elem in numpy.ndenumerate(array): | ||
| for beta in numpy.ndindex(child_shape): | ||
| direct_array[alpha + beta] = Indexed(array[alpha], beta) | ||
| direct_array[alpha + beta] = Indexed(elem, beta) | ||
| array = direct_array | ||
|
|
||
| # Constant folding | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import pytest | ||
| import gem | ||
| import numpy | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def A(): | ||
| a = gem.Variable("a", ()) | ||
| b = gem.Variable("b", ()) | ||
| c = gem.Variable("c", ()) | ||
| d = gem.Variable("d", ()) | ||
| array = [[a, b], [c, d]] | ||
| A = gem.ListTensor(array) | ||
| return A | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def X(): | ||
| return gem.Variable("X", (2, 2)) | ||
|
|
||
|
|
||
| def test_listtensor_from_indexed(X): | ||
| k = gem.Index() | ||
| elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])] | ||
| tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:])) | ||
|
|
||
| assert isinstance(tensor, gem.ComponentTensor) | ||
| j = tensor.multiindex | ||
| expected = gem.partial_indexed(X, (k,)) | ||
| expected = gem.ComponentTensor(gem.Indexed(expected, j), j) | ||
| assert tensor == expected | ||
|
|
||
|
|
||
| def test_listtensor_from_fixed_indexed(A): | ||
| elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)] | ||
| tensor = gem.ListTensor(numpy.reshape(elems, A.shape)) | ||
| assert tensor == A | ||
|
|
||
|
|
||
| def test_listtensor_from_partial_indexed(A): | ||
| elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])] | ||
| tensor = gem.ListTensor(elems) | ||
| assert tensor == A | ||
|
|
||
|
|
||
| def test_nested_partial_indexed(A): | ||
| i, j = gem.indices(2) | ||
| B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,)) | ||
| assert B == gem.Indexed(A, (i, j)) | ||
|
|
||
|
|
||
| def test_componenttensor_from_indexed(A): | ||
| i, j = gem.indices(2) | ||
| Aij = gem.Indexed(A, (i, j)) | ||
| assert A == gem.ComponentTensor(Aij, (i, j)) | ||
|
|
||
|
|
||
| def test_flatten_indexsum(A): | ||
| i, j = gem.indices(2) | ||
| Aij = gem.Indexed(A, (i, j)) | ||
|
|
||
| result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,)) | ||
| expected = gem.IndexSum(Aij, (i, j)) | ||
| assert result == expected |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if this is safe. This is a recursion, and, unlike when we use
DAGTraverser, the result is not cached.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the issue? Are we potentially creating new objects and not freeing memory for the old ones?