From de501c10b9adee96ec3468f62b972a16dde97085 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 8 Aug 2025 10:29:01 +0100 Subject: [PATCH 01/15] GEM: simplify indexed --- gem/gem.py | 46 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 974556754..f468747bd 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -685,12 +685,31 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, Zero): return Zero(dtype=aggregate.dtype) - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex], dtype=aggregate.dtype) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] + # Simplify Literal and ListTensor + if isinstance(aggregate, (Constant, ListTensor)): + if all(isinstance(i, int) for i in multiindex): + # All indices fixed + sub = aggregate.array[multiindex] + return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) + if isinstance(aggregate, ComponentTensor): + B, = aggregate.children + jj = aggregate.multiindex + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if all(j in kk for j in jj): + ii = tuple(multiindex) + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -836,6 +855,11 @@ def __new__(cls, expression, multiindex): if isinstance(expression, Zero): return Zero(shape, dtype=expression.dtype) + # Index folding + if isinstance(expression, Indexed): + if multiindex == expression.multiindex: + return expression.children[0] + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -892,9 +916,17 @@ def __new__(cls, array): dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape - child_shape = array.flat[0].shape + e0 = array.flat[0] + child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) + # Index folding + if child_shape == array.shape: + if all(isinstance(elem, Indexed) for elem in array.flat): + if all(elem.children == e0.children for elem in array.flat[1:]): + if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + return e0.children[0] + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) From 592061ea4f4e3416b3729f658cd6b2380de494a4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 13:30:05 +0100 Subject: [PATCH 02/15] Fixes for more complicated expressions --- gem/gem.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index f468747bd..79bcd2e79 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -274,6 +274,7 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) + return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -702,14 +703,29 @@ def __new__(cls, aggregate, multiindex): if isinstance(aggregate, ComponentTensor): B, = aggregate.children jj = aggregate.multiindex + ii = multiindex + # Avoid recursion and just attempt to simplify some common patterns + # as the result of this method is not cached. if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): - ii = tuple(multiindex) + if isinstance(C, ListTensor): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - return Indexed(C, ll) + B = Indexed(C, ll) + jj = tuple(j for j in jj if j not in kk) + ii = tuple(rep[j] for j in jj) + if len(ii) == 0: + return B + + if isinstance(B, Indexed): + C, = B.children + kk = B.multiindex + if not isinstance(C, ComponentTensor) or all(isinstance(i, Index) for i in ii): + if all(j in kk for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -722,6 +738,7 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) + self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self From 158b0ec918a2100f496d9b59bbcce5235cb24c24 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 13:41:42 +0100 Subject: [PATCH 03/15] small change --- gem/gem.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 111f4bb3b..cec66af68 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -273,7 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -294,7 +293,7 @@ def is_equal(self, other): return False if self.shape != other.shape: return False - return tuple(self.array.flat) == tuple(other.array.flat) + return numpy.array_equal(self.array, other.array) def get_hash(self): return hash((type(self), self.shape, tuple(self.array.flat))) @@ -737,7 +736,6 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) - self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self From d2a45843f703db6e8ceff49e43541812785d6b93 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 11 Aug 2025 17:50:56 +0100 Subject: [PATCH 04/15] More simplification --- gem/gem.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index cec66af68..9eb00f4ac 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -273,6 +273,9 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) + if numpy.allclose(array, 0, 1e-14): + return Zero(array.shape) + return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -690,6 +693,7 @@ def __new__(cls, aggregate, multiindex): # All indices fixed sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): # Some indices fixed slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) @@ -719,12 +723,20 @@ def __new__(cls, aggregate, multiindex): if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if not isinstance(C, ComponentTensor) or all(isinstance(i, Index) for i in ii): - if all(j in kk for j in jj): - rep = dict(zip(jj, ii)) - ll = tuple(rep.get(k, k) for k in kk) + if all(j in kk for j in jj): + rep = dict(zip(jj, ii)) + ll = tuple(rep.get(k, k) for k in kk) + if isinstance(C, ComponentTensor): + if (all(isinstance(i, Index) for i in ii) + or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): + return Indexed(C, ll) + else: return Indexed(C, ll) + if len(ii) < len(multiindex): + aggregate = ComponentTensor(B, jj) + multiindex = ii + self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) self.multiindex = multiindex From b49eb5cfd8bec337efbb323277df7f74ca4cfff5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 12:19:35 +0100 Subject: [PATCH 05/15] Do not replace free indices --- gem/gem.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 9eb00f4ac..2ee2fb87f 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + componentwise(Product, minus, as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -273,9 +273,6 @@ class Literal(Constant): def __new__(cls, array, dtype=None): array = asarray(array) - if numpy.allclose(array, 0, 1e-14): - return Zero(array.shape) - return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -706,36 +703,27 @@ def __new__(cls, aggregate, multiindex): B, = aggregate.children jj = aggregate.multiindex ii = multiindex - # Avoid recursion and just attempt to simplify some common patterns - # as the result of this method is not cached. + if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if isinstance(C, ListTensor): + if not isinstance(C, ComponentTensor): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) B = Indexed(C, ll) jj = tuple(j for j in jj if j not in kk) ii = tuple(rep[j] for j in jj) - if len(ii) == 0: + if not ii: return B if isinstance(B, Indexed): C, = B.children kk = B.multiindex - if all(j in kk for j in jj): + ff = C.free_indices + if all((j in kk) and (j not in ff) for j in jj): rep = dict(zip(jj, ii)) ll = tuple(rep.get(k, k) for k in kk) - if isinstance(C, ComponentTensor): - if (all(isinstance(i, Index) for i in ii) - or all(isinstance(l, Integral) or (l in C.multiindex) for l in ll)): - return Indexed(C, ll) - else: - return Indexed(C, ll) - - if len(ii) < len(multiindex): - aggregate = ComponentTensor(B, jj) - multiindex = ii + return Indexed(C, ll) self = super(Indexed, cls).__new__(cls) self.children = (aggregate,) @@ -1277,6 +1265,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar From 340ec400181d3fd1ae8469961d940488ee456ef0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 12 Aug 2025 22:22:28 +0100 Subject: [PATCH 06/15] Simplify IndexSum --- gem/gem.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2ee2fb87f..23e91ceec 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -896,6 +896,25 @@ def __new__(cls, summand, multiindex): if isinstance(summand, Zero): return summand + # No indices case + multiindex = tuple(multiindex) + if not multiindex: + return summand + + # Flatten nested sums + if isinstance(summand, IndexSum): + A, = summand.children + return IndexSum(A, summand.multiindex + multiindex) + + # Factor out common factors + if isinstance(summand, Product): + a, b = summand.children + if all(i not in a.free_indices for i in multiindex): + return Product(a, IndexSum(b, multiindex)) + + if all(i not in b.free_indices for i in multiindex): + return Product(IndexSum(a, multiindex), b) + # Unroll singleton sums unroll = tuple(index for index in multiindex if index.extent <= 1) if unroll: @@ -905,11 +924,6 @@ def __new__(cls, summand, multiindex): multiindex = tuple(index for index in multiindex if index not in unroll) - # No indices case - multiindex = tuple(multiindex) - if not multiindex: - return summand - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex From a7fda022d2d0470cefcd66cd65d209cc5297853b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 10:43:05 +0100 Subject: [PATCH 07/15] Refactor IndexSum unrolling --- gem/optimise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gem/optimise.py b/gem/optimise.py index 3c3c9bed7..69fbb8ce6 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -657,8 +657,8 @@ def _(node, self): # Unrolling summand = self(node.children[0]) shape = tuple(index.extent for index in unroll) - unrolled = Sum(*(Indexed(ComponentTensor(summand, unroll), alpha) - for alpha in numpy.ndindex(shape))) + tensor = ComponentTensor(summand, unroll) + unrolled = Sum(*(Indexed(tensor, alpha) for alpha in numpy.ndindex(shape))) return IndexSum(unrolled, tuple(index for index in node.multiindex if index not in unroll)) else: From eab0d90f9947ef1bace5ebc6e2906e6b2707f0d3 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 10:46:08 +0100 Subject: [PATCH 08/15] Flatten nested ComponentTensors --- gem/gem.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gem/gem.py b/gem/gem.py index 23e91ceec..a2c8fe7f5 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -874,6 +874,11 @@ def __new__(cls, expression, multiindex): if multiindex == expression.multiindex: return expression.children[0] + # Flatten nested ComponentTensors + if isinstance(expression, ComponentTensor): + A, = expression.children + return ComponentTensor(A, expression.multiindex + multiindex) + self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex From 41981ec30234a8d05cb296931a01553caf54042b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 16:32:01 +0100 Subject: [PATCH 09/15] use numpy.array_equal --- gem/optimise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gem/optimise.py b/gem/optimise.py index 69fbb8ce6..1aa2016b8 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -177,7 +177,7 @@ def _constant_fold_zero(node, self): @_constant_fold_zero.register(Literal) def _constant_fold_zero_literal(node, self): - if (node.array == 0).all(): + if numpy.array_equal(node.array, 0): # All zeros, make symbolic zero return Zero(node.shape) else: From 7f580284a69e8cf32e2de3856c3415ad5059a5f7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 17:36:52 +0100 Subject: [PATCH 10/15] Simplify ListTensor(ComponentTensor(Indexed(...))) --- gem/gem.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index a2c8fe7f5..1eb591a4e 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -691,13 +691,6 @@ def __new__(cls, aggregate, multiindex): sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub - elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): - # Some indices fixed - slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) - sub = aggregate.array[slices] - sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) - return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) - # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): B, = aggregate.children @@ -953,12 +946,23 @@ def __new__(cls, array): child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) - # Index folding - if child_shape == array.shape: - if all(isinstance(elem, Indexed) for elem in array.flat): - if all(elem.children == e0.children for elem in array.flat[1:]): + # Simplify [v[j] for j in range(n)] -> v + if all(isinstance(elem, Indexed) for elem in array.flat): + tensor = e0.children[0] + if array.shape + child_shape == tensor.shape: + if all(elem.children[0] == tensor for elem in array.flat[1:]): if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): - return e0.children[0] + return tensor + + # Simplify [v[j, :] for j in range(n)] -> v + if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) + for elem in array.flat): + tensor = e0.children[0].children[0] + if array.shape + child_shape == tensor.shape: + if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]): + if all(elem.children[0].multiindex == idx + elem.multiindex + for idx, elem in zip(numpy.ndindex(array.shape), array.flat)): + return tensor if child_shape: # Destroy structure From 7077cf9514cb971e67e27cf2c299ee1b04d47534 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 17:56:27 +0100 Subject: [PATCH 11/15] Some indices fixed --- gem/gem.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gem/gem.py b/gem/gem.py index 1eb591a4e..2963161ea 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -691,6 +691,13 @@ def __new__(cls, aggregate, multiindex): sub = aggregate.array[multiindex] return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub + elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex): + # Some indices fixed + slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex) + sub = aggregate.array[slices] + sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub) + return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int))) + # Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll) if isinstance(aggregate, ComponentTensor): B, = aggregate.children From 55bb3136bc70305e9d4a022e465c00c57c362a05 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 23:39:22 +0100 Subject: [PATCH 12/15] style --- gem/gem.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2963161ea..1b94f09cc 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -901,6 +901,15 @@ def __new__(cls, summand, multiindex): if isinstance(summand, Zero): return summand + # Unroll singleton sums + unroll = tuple(index for index in multiindex if index.extent <= 1) + if unroll: + assert numpy.prod([index.extent for index in unroll]) == 1 + summand = Indexed(ComponentTensor(summand, unroll), + (0,) * len(unroll)) + multiindex = tuple(index for index in multiindex + if index not in unroll) + # No indices case multiindex = tuple(multiindex) if not multiindex: @@ -920,15 +929,6 @@ def __new__(cls, summand, multiindex): if all(i not in b.free_indices for i in multiindex): return Product(IndexSum(a, multiindex), b) - # Unroll singleton sums - unroll = tuple(index for index in multiindex if index.extent <= 1) - if unroll: - assert numpy.prod([index.extent for index in unroll]) == 1 - summand = Indexed(ComponentTensor(summand, unroll), - (0,) * len(unroll)) - multiindex = tuple(index for index in multiindex - if index not in unroll) - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex @@ -958,7 +958,7 @@ def __new__(cls, array): tensor = e0.children[0] if array.shape + child_shape == tensor.shape: if all(elem.children[0] == tensor for elem in array.flat[1:]): - if all(elem.multiindex == idx for elem, idx in zip(array.flat, numpy.ndindex(array.shape))): + if all(elem.multiindex == idx for idx, elem in numpy.ndenumerate(array)): return tensor # Simplify [v[j, :] for j in range(n)] -> v @@ -968,15 +968,19 @@ def __new__(cls, array): if array.shape + child_shape == tensor.shape: if all(elem.children[0].children[0] == tensor for elem in array.flat[1:]): if all(elem.children[0].multiindex == idx + elem.multiindex - for idx, elem in zip(numpy.ndindex(array.shape), array.flat)): + for idx, elem in numpy.ndenumerate(array)): return tensor + # Flatten nested ListTensors + if all(isinstance(elem, ListTensor) for elem in array.flat): + return ListTensor(asarray([elem.array for elem in array.flat]).reshape(array.shape + child_shape)) + if child_shape: # Destroy structure direct_array = numpy.empty(array.shape + child_shape, dtype=object) - for alpha in numpy.ndindex(array.shape): + for alpha, elem in numpy.ndenumerate(array): for beta in numpy.ndindex(child_shape): - direct_array[alpha + beta] = Indexed(array[alpha], beta) + direct_array[alpha + beta] = Indexed(elem, beta) array = direct_array # Constant folding From 10a8604ea94b92d6eb87b502279dc14f978f2ae2 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 26 Aug 2025 17:52:12 +0100 Subject: [PATCH 13/15] Add tests --- gem/gem.py | 17 +------------- test/gem/test_simplify.py | 47 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 16 deletions(-) create mode 100644 test/gem/test_simplify.py diff --git a/gem/gem.py b/gem/gem.py index 1b94f09cc..2638effe4 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -97,7 +97,7 @@ def __radd__(self, other): def __sub__(self, other): return componentwise( Sum, self, - componentwise(Product, minus, as_gem(other))) + componentwise(Product, Literal(-1), as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -874,11 +874,6 @@ def __new__(cls, expression, multiindex): if multiindex == expression.multiindex: return expression.children[0] - # Flatten nested ComponentTensors - if isinstance(expression, ComponentTensor): - A, = expression.children - return ComponentTensor(A, expression.multiindex + multiindex) - self = super(ComponentTensor, cls).__new__(cls) self.children = (expression,) self.multiindex = multiindex @@ -920,15 +915,6 @@ def __new__(cls, summand, multiindex): A, = summand.children return IndexSum(A, summand.multiindex + multiindex) - # Factor out common factors - if isinstance(summand, Product): - a, b = summand.children - if all(i not in a.free_indices for i in multiindex): - return Product(a, IndexSum(b, multiindex)) - - if all(i not in b.free_indices for i in multiindex): - return Product(IndexSum(a, multiindex), b) - self = super(IndexSum, cls).__new__(cls) self.children = (summand,) self.multiindex = multiindex @@ -1299,7 +1285,6 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) -minus = Literal(-1) # Syntax sugar diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py new file mode 100644 index 000000000..0658a7f56 --- /dev/null +++ b/test/gem/test_simplify.py @@ -0,0 +1,47 @@ +import pytest +import gem +import numpy + + +@pytest.fixture +def A(): + a = gem.Variable("a", ()) + b = gem.Variable("b", ()) + c = gem.Variable("c", ()) + d = gem.Variable("d", ()) + array = [[a, b], [c, d]] + A = gem.ListTensor(array) + return A + + +def test_listtensor_from_indexed(A): + elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)] + tensor = gem.ListTensor(numpy.reshape(elems, A.shape)) + assert tensor == A + + +def test_listtensor_from_partial_indexed(A): + elems = [gem.partial_indexed(A, i) for i in numpy.ndindex(A.shape[:1])] + tensor = gem.ListTensor(elems) + assert tensor == A + + +def test_nested_partial_indexed(A): + i, j = gem.indices(2) + B = gem.partial_indexed(gem.partial_indexed(A, (i,)), (j,)) + assert B == gem.Indexed(A, (i, j)) + + +def test_componenttensor_from_indexed(A): + i, j = gem.indices(2) + Aij = gem.Indexed(A, (i, j)) + assert A == gem.ComponentTensor(Aij, (i, j)) + + +def test_flatten_indexsum(A): + i, j = gem.indices(2) + Aij = gem.Indexed(A, (i, j)) + + result = gem.IndexSum(gem.IndexSum(Aij, (i,)), (j,)) + expected = gem.IndexSum(Aij, (i, j)) + assert result == expected From b738ec64f0eb8e7ac9197fd255b27b4bc25722c6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 27 Aug 2025 11:34:03 +0100 Subject: [PATCH 14/15] More simplify --- gem/gem.py | 12 +++++++----- test/gem/test_simplify.py | 19 ++++++++++++++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 2638effe4..c10198a7e 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -939,15 +939,17 @@ def __new__(cls, array): child_shape = e0.shape assert all(elem.shape == child_shape for elem in array.flat) - # Simplify [v[j] for j in range(n)] -> v + # Simplify [v[multiindex, j] for j in range(n)] -> partial_indexed(v, multiindex) if all(isinstance(elem, Indexed) for elem in array.flat): tensor = e0.children[0] - if array.shape + child_shape == tensor.shape: + multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral)) + index_shape = tuple(i.extent for i in multiindex) + if index_shape + array.shape + child_shape == tensor.shape: if all(elem.children[0] == tensor for elem in array.flat[1:]): - if all(elem.multiindex == idx for idx, elem in numpy.ndenumerate(array)): - return tensor + if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)): + return partial_indexed(tensor, multiindex) - # Simplify [v[j, :] for j in range(n)] -> v + # Simplify [v[j, ...] for j in range(n)] -> v if all(isinstance(elem, ComponentTensor) and isinstance(elem.children[0], Indexed) for elem in array.flat): tensor = e0.children[0].children[0] diff --git a/test/gem/test_simplify.py b/test/gem/test_simplify.py index 0658a7f56..9d7b52a49 100644 --- a/test/gem/test_simplify.py +++ b/test/gem/test_simplify.py @@ -14,7 +14,24 @@ def A(): return A -def test_listtensor_from_indexed(A): +@pytest.fixture +def X(): + return gem.Variable("X", (2, 2)) + + +def test_listtensor_from_indexed(X): + k = gem.Index() + elems = [gem.Indexed(X, (k, *i)) for i in numpy.ndindex(X.shape[1:])] + tensor = gem.ListTensor(numpy.reshape(elems, X.shape[1:])) + + assert isinstance(tensor, gem.ComponentTensor) + j = tensor.multiindex + expected = gem.partial_indexed(X, (k,)) + expected = gem.ComponentTensor(gem.Indexed(expected, j), j) + assert tensor == expected + + +def test_listtensor_from_fixed_indexed(A): elems = [gem.Indexed(A, i) for i in numpy.ndindex(A.shape)] tensor = gem.ListTensor(numpy.reshape(elems, A.shape)) assert tensor == A From 5dfd17de4df9e3685b094877b3f31dee9563a2c0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 29 Aug 2025 14:10:30 +0100 Subject: [PATCH 15/15] Fix up --- gem/gem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gem/gem.py b/gem/gem.py index 463a9111b..8601c1c20 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -943,7 +943,7 @@ def __new__(cls, array): if all(isinstance(elem, Indexed) for elem in array.flat): tensor = e0.children[0] multiindex = tuple(i for i in e0.multiindex if not isinstance(i, Integral)) - index_shape = tuple(i.extent for i in multiindex) + index_shape = tuple(i.extent for i in multiindex if isinstance(i, Index)) if index_shape + array.shape + child_shape == tensor.shape: if all(elem.children[0] == tensor for elem in array.flat[1:]): if all(elem.multiindex == multiindex + idx for idx, elem in numpy.ndenumerate(array)):