gem/gem.py (89 changes: 79 additions & 10 deletions)
@@ -293,7 +293,7 @@ def is_equal(self, other):
return False
if self.shape != other.shape:
return False
-        return tuple(self.array.flat) == tuple(other.array.flat)
+        return numpy.array_equal(self.array, other.array)

def get_hash(self):
return hash((type(self), self.shape, tuple(self.array.flat)))
@@ -684,12 +684,46 @@ def __new__(cls, aggregate, multiindex):
if isinstance(aggregate, Zero):
return Zero(dtype=aggregate.dtype)

-        # All indices fixed
-        if all(isinstance(i, int) for i in multiindex):
-            if isinstance(aggregate, Constant):
-                return Literal(aggregate.array[multiindex], dtype=aggregate.dtype)
-            elif isinstance(aggregate, ListTensor):
-                return aggregate.array[multiindex]
+        # Simplify Literal and ListTensor
+        if isinstance(aggregate, (Constant, ListTensor)):
+            if all(isinstance(i, int) for i in multiindex):
+                # All indices fixed
+                sub = aggregate.array[multiindex]
+                return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
+
+            elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
+                # Some indices fixed
+                slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
+                sub = aggregate.array[slices]
+                sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
+                return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))

Review comment: I'm not sure if this is safe. This is a recursion, and, unlike when we use DAGTraverser, the result is not cached.

Author: What is the issue? Are we potentially creating new objects and not freeing memory for the old ones?
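
To make the concern concrete, here is a toy sketch in plain Python (hypothetical classes, not the GEM API, and not part of the patch): plain recursion over a DAG re-processes, and duplicates, every shared subexpression once per parent, while a DAGTraverser-style memoised walk touches each distinct node once and preserves sharing.

```python
# Toy illustration only: hypothetical Node class, not the GEM API.
class Node:
    def __init__(self, *children):
        self.children = children


def rebuild_uncached(node, counter):
    # Plain recursion: a child shared by several parents is processed
    # (and copied) once per parent, so the result is a tree, not a DAG.
    counter[0] += 1
    return Node(*(rebuild_uncached(c, counter) for c in node.children))


def rebuild_cached(node, counter, cache):
    # DAGTraverser-style: memoise on node identity, so each distinct node
    # is processed once and sharing is preserved in the rebuilt DAG.
    if id(node) not in cache:
        counter[0] += 1
        cache[id(node)] = Node(*(rebuild_cached(c, counter, cache)
                                 for c in node.children))
    return cache[id(node)]


# A DAG of only 21 distinct nodes, but with heavy sharing.
node = Node()
for _ in range(20):
    node = Node(node, node)

n_uncached, n_cached = [0], [0]
rebuild_uncached(node, n_uncached)
rebuild_cached(node, n_cached, {})
print(n_uncached[0], n_cached[0])  # 2097151 vs 21
```

In the toy, the uncached rebuild also loses the sharing (its result is a tree), which is the memory side of the question above.
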
# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
if isinstance(aggregate, ComponentTensor):
B, = aggregate.children
jj = aggregate.multiindex
ii = multiindex

if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
if not isinstance(C, ComponentTensor):

Review comment: Why do we need to avoid ComponentTensor here?

Author: I think it is because, for ComponentTensor, we cannot replace indices if the substitution involves the free indices; for other classes this seems fine. (A loose analogy is sketched after this hunk.)

rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
B = Indexed(C, ll)
jj = tuple(j for j in jj if j not in kk)
ii = tuple(rep[j] for j in jj)
if not ii:
return B

if isinstance(B, Indexed):
C, = B.children
kk = B.multiindex
ff = C.free_indices
if all((j in kk) and (j not in ff) for j in jj):
rep = dict(zip(jj, ii))
ll = tuple(rep.get(k, k) for k in kk)
return Indexed(C, ll)

Review comment: This also causes recursion.

self = super(Indexed, cls).__new__(cls)
self.children = (aggregate,)
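
Regarding the ComponentTensor restriction discussed above, a loose analogy in plain Python (not the GEM index machinery, and not part of the patch): ComponentTensor binds its multiindex much like a lambda binds its argument, so naively substituting an index underneath such a binder can capture it and change the meaning of the expression.

```python
# Toy illustration only: a Bind node binds a name, like ComponentTensor binds
# its multiindex; substitution that descends under the binder can capture a
# free name.
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Bind:
    body: object
    name: str  # bound inside `body`


def subst(expr, old, new):
    """Rename `old` to `new`, ignoring binders (the unsafe part)."""
    if isinstance(expr, Var):
        return Var(new) if expr.name == old else expr
    if isinstance(expr, Add):
        return Add(subst(expr.a, old, new), subst(expr.b, old, new))
    return Bind(subst(expr.body, old, new), expr.name)


def evaluate(expr, env):
    if isinstance(expr, Var):
        return env[expr.name]
    return evaluate(expr.a, env) + evaluate(expr.b, env)


# body = i + j, with j bound by the Bind node and i free.
bound = Bind(Add(Var("i"), Var("j")), "j")

# Renaming the free i to a fresh name k is harmless ...
print(evaluate(subst(bound, "i", "k").body, {"k": 1, "j": 100}))  # 101
# ... but renaming it to the bound name j captures it: i + j becomes j + j.
print(evaluate(subst(bound, "i", "j").body, {"j": 100}))          # 200
```
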
@@ -835,6 +869,11 @@ def __new__(cls, expression, multiindex):
if isinstance(expression, Zero):
return Zero(shape, dtype=expression.dtype)

# Index folding
if isinstance(expression, Indexed):
if multiindex == expression.multiindex:
return expression.children[0]

self = super(ComponentTensor, cls).__new__(cls)
self.children = (expression,)
self.multiindex = multiindex
@@ -871,6 +910,11 @@ def __new__(cls, summand, multiindex):
if not multiindex:
return summand

# Flatten nested sums
if isinstance(summand, IndexSum):
A, = summand.children
return IndexSum(A, summand.multiindex + multiindex)

self = super(IndexSum, cls).__new__(cls)
self.children = (summand,)
self.multiindex = multiindex
@@ -891,15 +935,40 @@ def __new__(cls, array):
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
-        child_shape = array.flat[0].shape
+        e0 = array.flat[0]
+        child_shape = e0.shape
assert all(elem.shape == child_shape for elem in array.flat)

# Simplify [v[multiindex, j] for j in range(n)] -> partial_indexed(v, multiindex)
if all(isinstance(elem, Indexed) for elem in array.flat):
tensor = e0.children[0]
multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral))
index_shape = tuple(i.extent for i in multiindex if isinstance(i, Index))
if index_shape + array.shape + child_shape == tensor.shape:
if all(elem.children[0] == tensor for elem in array.flat[1:]):
if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)):
return partial_indexed(tensor, multiindex)
Review comment on lines +943 to +950: Maybe we should make it more explicit that we are only handling a multiindex of the following pattern: (Index, Index, Index, ..., Integral, Integral, ...). It looks like the case where we have VariableIndex objects, for instance, is handled in a very obscure way.

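
A sketch of the kind of explicit check suggested here (hypothetical helper, not part of the patch): accept only multiindices that are a run of Index objects followed by a run of fixed Integral positions, and bail out on anything else, e.g. a VariableIndex, up front.

```python
# Hypothetical helper sketching an explicit (Index, ..., Integral, ...) check;
# `index_cls` would be gem's Index class.
from numbers import Integral


def match_index_then_integral(multiindex, index_cls):
    """Return (free_part, fixed_part) if multiindex matches the pattern, else None."""
    k = 0
    while k < len(multiindex) and isinstance(multiindex[k], index_cls):
        k += 1
    free_part, fixed_part = multiindex[:k], multiindex[k:]
    if all(isinstance(i, Integral) for i in fixed_part):
        return free_part, fixed_part
    return None  # e.g. a VariableIndex, or Index/Integral interleaved


# Usage inside ListTensor.__new__ might then look like:
#     pattern = match_index_then_integral(e0.multiindex, Index)
#     if pattern is not None:
#         free_part, fixed_part = pattern
#         ...
```
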

# Simplify [v[j, ...] for j in range(n)] -> v
if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed)
for elem in array.flat):
tensor = e0.children[0].children[0]
if array.shape + child_shape == tensor.shape:
if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]):
if all(elem.children[0].multiindex == idx + elem.multiindex
for idx, elem in numpy.ndenumerate(array)):
return tensor

# Flatten nested ListTensors
if all(isinstance(elem, ListTensor) for elem in array.flat):
return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape))

if child_shape:
# Destroy structure
direct_array = numpy.empty(array.shape + child_shape, dtype=object)
-            for alpha in numpy.ndindex(array.shape):
+            for alpha, elem in numpy.ndenumerate(array):
                 for beta in numpy.ndindex(child_shape):
-                    direct_array[alpha + beta] = Indexed(array[alpha], beta)
+                    direct_array[alpha + beta] = Indexed(elem, beta)
array = direct_array

# Constant folding
gem/optimise.py (6 changes: 3 additions & 3 deletions)
@@ -177,7 +177,7 @@ def _constant_fold_zero(node, self):

@_constant_fold_zero.register(Literal)
def _constant_fold_zero_literal(node, self):
-    if (node.array == 0).all():
+    if numpy.array_equal(node.array, 0):
# All zeros, make symbolic zero
return Zero(node.shape)
else:
@@ -663,8 +663,8 @@ def _(node, self):
# Unrolling
summand = self(node.children[0])
shape = tuple(index.extent for index in unroll)
-        unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha)
-                         for alpha in numpy.ndindex(shape)))
+        tensor = ComponentTensor(summand, unroll)
+        unrolled = Sum(*(Indexed(tensor, alpha) for alpha in numpy.ndindex(shape)))
return IndexSum(unrolled, tuple(index for index in node.multiindex
if index not in unroll))
else:
test/gem/test_simplify.py (64 changes: 64 additions & 0 deletions)
@@ -0,0 +1,64 @@
import pytest
import gem
import numpy


@pytest.fixture
def A():
a = gem.Variable("a", ())
b = gem.Variable("b", ())
c = gem.Variable("c", ())
d = gem.Variable("d", ())
array = [[a, b], [c, d]]
A = gem.ListTensor(array)
return A


@pytest.fixture
def X():
return gem.Variable("X", (2, 2))


def test_listtensor_from_indexed(X):
k = gem.Index()
elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])]
tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:]))

assert isinstance(tensor, gem.ComponentTensor)
j = tensor.multiindex
expected = gem.partial_indexed(X, (k,))
expected = gem.ComponentTensor(gem.Indexed(expected, j), j)
assert tensor == expected


def test_listtensor_from_fixed_indexed(A):
elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)]
tensor = gem.ListTensor(numpy.reshape(elems, A.shape))
assert tensor == A


def test_listtensor_from_partial_indexed(A):
elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])]
tensor = gem.ListTensor(elems)
assert tensor == A


def test_nested_partial_indexed(A):
i, j = gem.indices(2)
B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,))
assert B == gem.Indexed(A, (i, j))


def test_componenttensor_from_indexed(A):
i, j = gem.indices(2)
Aij = gem.Indexed(A, (i, j))
assert A == gem.ComponentTensor(Aij, (i, j))


def test_flatten_indexsum(A):
i, j = gem.indices(2)
Aij = gem.Indexed(A, (i, j))

result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,))
expected = gem.IndexSum(Aij, (i, j))
assert result == expected