diff --git a/gem/gem.py b/gem/gem.py
index b31fd950..8601c1c2 100644
--- a/gem/gem.py
+++ b/gem/gem.py
@@ -293,7 +293,7 @@ def is_equal(self, other):
             return False
         if self.shape != other.shape:
             return False
-        return tuple(self.array.flat) == tuple(other.array.flat)
+        return numpy.array_equal(self.array, other.array)
 
     def get_hash(self):
         return hash((type(self), self.shape, tuple(self.array.flat)))
@@ -684,12 +684,46 @@ def __new__(cls, aggregate, multiindex):
         if isinstance(aggregate, Zero):
             return Zero(dtype=aggregate.dtype)
 
-        # All indices fixed
-        if all(isinstance(i, int) for i in multiindex):
-            if isinstance(aggregate, Constant):
-                return Literal(aggregate.array[multiindex], dtype=aggregate.dtype)
-            elif isinstance(aggregate, ListTensor):
-                return aggregate.array[multiindex]
+        # Simplify Literal and ListTensor
+        if isinstance(aggregate, (Constant, ListTensor)):
+            if all(isinstance(i, int) for i in multiindex):
+                # All indices fixed
+                sub = aggregate.array[multiindex]
+                return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
+
+            elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
+                # Some indices fixed
+                slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
+                sub = aggregate.array[slices]
+                sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
+                return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))
+
+        # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
+        if isinstance(aggregate, ComponentTensor):
+            B, = aggregate.children
+            jj = aggregate.multiindex
+            ii = multiindex
+
+            if isinstance(B, Indexed):
+                C, = B.children
+                kk = B.multiindex
+                if not isinstance(C, ComponentTensor):
+                    rep = dict(zip(jj, ii))
+                    ll = tuple(rep.get(k, k) for k in kk)
+                    B = Indexed(C, ll)
+                    jj = tuple(j for j in jj if j not in kk)
+                    ii = tuple(rep[j] for j in jj)
+                    if not ii:
+                        return B
+
+            if isinstance(B, Indexed):
+                C, = B.children
+                kk = B.multiindex
+                ff = C.free_indices
+                if all((j in kk) and (j not in ff) for j in jj):
+                    rep = dict(zip(jj, ii))
+                    ll = tuple(rep.get(k, k) for k in kk)
+                    return Indexed(C, ll)
         self = super(Indexed, cls).__new__(cls)
         self.children = (aggregate,)
         self.multiindex = multiindex
@@ -835,6 +869,11 @@ def __new__(cls, expression, multiindex):
         if isinstance(expression, Zero):
             return Zero(shape, dtype=expression.dtype)
 
+        # Index folding
+        if isinstance(expression, Indexed):
+            if multiindex == expression.multiindex:
+                return expression.children[0]
+
         self = super(ComponentTensor, cls).__new__(cls)
         self.children = (expression,)
         self.multiindex = multiindex
@@ -871,6 +910,11 @@ def __new__(cls, summand, multiindex):
         if not multiindex:
             return summand
 
+        # Flatten nested sums
+        if isinstance(summand, IndexSum):
+            A, = summand.children
+            return IndexSum(A, summand.multiindex + multiindex)
+
         self = super(IndexSum, cls).__new__(cls)
         self.children = (summand,)
         self.multiindex = multiindex
@@ -891,15 +935,40 @@ def __new__(cls, array):
         dtype = Node.inherit_dtype_from_children(tuple(array.flat))
 
         # Handle children with shape
-        child_shape = array.flat[0].shape
+        e0 = array.flat[0]
+        child_shape = e0.shape
         assert all(elem.shape == child_shape for elem in array.flat)
 
+        # Simplify [v[multiindex, j] for j in range(n)] -> partial_indexed(v, multiindex)
+        if all(isinstance(elem, Indexed) for elem in array.flat):
+            tensor = e0.children[0]
+            multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral))
+            index_shape = tuple(i.extent for i in multiindex if isinstance(i, Index))
+            if index_shape + array.shape + child_shape == tensor.shape:
+                if all(elem.children[0] == tensor for elem in array.flat[1:]):
+                    if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)):
+                        return partial_indexed(tensor, multiindex)
+
+        # Simplify [v[j, ...] for j in range(n)] -> v
+        if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed)
+               for elem in array.flat):
+            tensor = e0.children[0].children[0]
+            if array.shape + child_shape == tensor.shape:
+                if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]):
+                    if all(elem.children[0].multiindex == idx + elem.multiindex
+                           for idx, elem in numpy.ndenumerate(array)):
+                        return tensor
+
+        # Flatten nested ListTensors
+        if all(isinstance(elem, ListTensor) for elem in array.flat):
+            return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape))
+
         if child_shape:
             # Destroy structure
             direct_array = numpy.empty(array.shape + child_shape, dtype=object)
-            for alpha in numpy.ndindex(array.shape):
+            for alpha, elem in numpy.ndenumerate(array):
                 for beta in numpy.ndindex(child_shape):
-                    direct_array[alpha + beta] = Indexed(array[alpha], beta)
+                    direct_array[alpha + beta] = Indexed(elem, beta)
             array = direct_array
 
         # Constant folding
diff --git a/gem/optimise.py b/gem/optimise.py
index 289ccff7..ad91e1f5 100644
--- a/gem/optimise.py
+++ b/gem/optimise.py
@@ -177,7 +177,7 @@ def _constant_fold_zero(node, self):
 
 @_constant_fold_zero.register(Literal)
 def _constant_fold_zero_literal(node, self):
-    if (node.array == 0).all():
+    if not node.array.any():
         # All zeros, make symbolic zero
         return Zero(node.shape)
     else:
@@ -663,8 +663,8 @@ def _(node, self):
         # Unrolling
         summand = self(node.children[0])
         shape = tuple(index.extent for index in unroll)
-        unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha)
-                         for alpha in numpy.ndindex(shape)))
+        tensor = ComponentTensor(summand, unroll)
+        unrolled = Sum(*(Indexed(tensor, alpha) for alpha in numpy.ndindex(shape)))
         return IndexSum(unrolled, tuple(index for index in node.multiindex
                                         if index not in unroll))
     else:
diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py
new file mode 100644
index 00000000..9d7b52a4
--- /dev/null
+++ b/test/gem/test_simplify.py
@@ -0,0 +1,64 @@
+import pytest
+import gem
+import numpy
+
+
+@pytest.fixture
+def A():
+    a = gem.Variable("a", ())
+    b = gem.Variable("b", ())
+    c = gem.Variable("c", ())
+    d = gem.Variable("d", ())
+    array = [[a, b], [c, d]]
+    A = gem.ListTensor(array)
+    return A
+
+
+@pytest.fixture
+def X():
+    return gem.Variable("X", (2, 2))
+
+
+def test_listtensor_from_indexed(X):
+    k = gem.Index()
+    elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])]
+    tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:]))
+
+    assert isinstance(tensor, gem.ComponentTensor)
+    j = tensor.multiindex
+    expected = gem.partial_indexed(X, (k,))
+    expected = gem.ComponentTensor(gem.Indexed(expected, j), j)
+    assert tensor == expected
+
+
+def test_listtensor_from_fixed_indexed(A):
+    elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)]
+    tensor = gem.ListTensor(numpy.reshape(elems, A.shape))
+    assert tensor == A
+
+
+def test_listtensor_from_partial_indexed(A):
+    elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])]
+    tensor = gem.ListTensor(elems)
+    assert tensor == A
+
+
+def test_nested_partial_indexed(A):
+    i, j = gem.indices(2)
+    B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,))
+    assert B == gem.Indexed(A, (i, j))
+
+
+def test_componenttensor_from_indexed(A):
+    i, j = gem.indices(2)
+    Aij = gem.Indexed(A, (i, j))
+    assert A == gem.ComponentTensor(Aij, (i, j))
+
+
+def test_flatten_indexsum(A):
+    i, j = gem.indices(2)
+    Aij = gem.Indexed(A, (i, j))
+
+    result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,))
+    expected = gem.IndexSum(Aij, (i, j))
+    assert result == expected