diff --git a/gem/gem.py b/gem/gem.py index 57014545..81694426 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -18,6 +18,7 @@ from itertools import chain from operator import attrgetter from numbers import Integral, Number +import collections import numpy from numpy import asarray @@ -33,8 +34,21 @@ 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'index_sum', 'partial_indexed', 'reshape', 'view', 'indices', 'as_gem', 'FlexiblyIndexed', - 'Inverse', 'Solve'] - + 'Inverse', 'Solve', 'Action', 'MatfreeSolveContext'] + + +# Defaults are in the order of items in MatfreeSolveContext +MatfreeSolveContext = collections.namedtuple("MatfreeSolveContext", + ["matfree", + "Aonx", + "Aonp", + "preconditioner", + "Ponr", + "diag_prec", + "rtol", + "atol", + "max_it"]) +DEFAULT_MSC = MatfreeSolveContext(*(False, None, None, None, None, False, "1.e-12", "1.e-70", None)) class NodeMeta(type): """Metaclass of GEM nodes. @@ -249,9 +263,11 @@ class Variable(Terminal): __slots__ = ('name', 'shape') __front__ = ('name', 'shape') + id = 0 def __init__(self, name, shape): - self.name = name + Variable.id += 1 + self.name = "T%d" % Variable.id if not name else name self.shape = shape @@ -834,18 +850,62 @@ class Solve(Node): Represents the X obtained by solving AX = B. """ - __slots__ = ('children', 'shape') + __slots__ = ('children', 'shape', 'name', 'ctx') + __back__ = ('name', 'ctx') - def __init__(self, A, B): + id = 0 + + def __new__(cls, A, B, name="", ctx=DEFAULT_MSC): # Shape requirements assert B.shape assert len(A.shape) == 2 assert A.shape[0] == A.shape[1] assert A.shape[0] == B.shape[0] + self = super(Solve, cls).__new__(cls) self.children = (A, B) self.shape = A.shape[1:] + B.shape[1:] + # We use a ctx rather than kwargs because then __slots__ and __back__ are independent + # of the extra arguments passed for the matrix-free, iterative solver + # Values in default args are overwritten if there is a corresponding kwarg + # It's not save to make defaults a nested dict + updated_ctx = DEFAULT_MSC._asdict().copy() + updated_ctx.update(ctx._asdict()) + self.ctx = MatfreeSolveContext(**updated_ctx) + + + # When nodes are reconstructed in the GEM optimiser, + # we want them to keep their names which is why + # there is an optional name keyword in this constructor + self.name = name if name else "S%d" % Solve.id + Solve.id += 1 + return self + + +class Action(Node): + __slots__ = ('children', 'shape', 'pick_op', 'name') + __back__ = ('pick_op', 'name') + id = 0 + + def __new__(cls, A, B, pick_op, name=""): + assert B.shape + assert len(A.shape) == 2 + assert A.shape[pick_op] == B.shape[0] + assert pick_op < 2 + + self = super(Action, cls).__new__(cls) + self.children = A, B + self.shape = (A.shape[pick_op ^ 1],) + self.pick_op = pick_op + + # When nodes are reconstructed in the GEM optimiser, + # we want them to keep their names which is why + # there is an optional name keyword in this constructor + self.name = name if name else "A%d" % Action.id + Action.id += 1 + return self + def unique(indices): """Sorts free indices and eliminates duplicates. @@ -903,7 +963,7 @@ def strides_of(shape): def decompose_variable_view(expression): """Extract information from a shaped node. Decompose ComponentTensor + FlexiblyIndexed.""" - if (isinstance(expression, (Variable, Inverse, Solve))): + if (isinstance(expression, (Variable, Inverse, Solve, Action))): variable = expression indexes = tuple(Index(extent=extent) for extent in expression.shape) dim2idxs = tuple((0, ((index, 1),)) for index in indexes) diff --git a/tsfc/fem.py b/tsfc/fem.py index 2c5c865c..555a8a93 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -608,7 +608,15 @@ def fiat_to_ufl(fiat_dict, order): @translate.register(Argument) def translate_argument(terminal, mt, ctx): - argument_multiindex = ctx.argument_multiindices[terminal.number()] + # The following try except is need due to introducing an optimiser and actions in Slate. + # When action(Transpose(Tensor(form)), AssembledVector(f)) is part of an expression, + # it gets translated into a form where the first argument is replaced by the coefficient f. + # FIXME This is ugly, maybe we can introduce new information on the ctx to make it cleaner. + no = terminal.number() + try: + argument_multiindex = ctx.argument_multiindices[no] + except IndexError: + argument_multiindex = ctx.argument_multiindices[0] sigma = tuple(gem.Index(extent=d) for d in mt.expr.ufl_shape) element = ctx.create_element(terminal.ufl_element(), restriction=mt.restriction) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index ae8f93e5..c6dd2462 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -113,13 +113,14 @@ def assign_dtypes(expressions, scalar_type): class LoopyContext(object): - def __init__(self, target=None): + def __init__(self, kernel_name, target=None): self.indices = {} # indices for declarations and referencing values, from ImperoC self.active_indices = {} # gem index -> pymbolic variable self.index_extent = OrderedDict() # pymbolic variable for indices -> extent self.gem_to_pymbolic = {} # gem node -> pymbolic variable self.name_gen = UniqueNameGenerator() self.target = target + self.kernel_name = kernel_name def fetch_multiindex(self, multiindex): indices = [] @@ -199,7 +200,7 @@ def active_indices(mapping, ctx): def generate(impero_c, args, scalar_type, kernel_name="loopy_kernel", index_names=[], - return_increments=True, log=False): + return_increments=True, return_ctx=False, log=False): """Generates loopy code. :arg impero_c: ImperoC tuple with Impero AST and other data @@ -209,9 +210,10 @@ def generate(impero_c, args, scalar_type, kernel_name="loopy_kernel", index_name :arg index_names: pre-assigned index names :arg return_increments: Does codegen for Return nodes increment the lvalue, or assign? :arg log: bool if the Kernel should be profiled with Log events + :arg return_ctx: Is the ctx returned alongside the generated kernel? :returns: loopy kernel """ - ctx = LoopyContext(target=target) + ctx = LoopyContext(kernel_name, target=target) ctx.indices = impero_c.indices ctx.index_names = defaultdict(lambda: "i", index_names) ctx.epsilon = numpy.finfo(scalar_type).resolution @@ -221,7 +223,7 @@ def generate(impero_c, args, scalar_type, kernel_name="loopy_kernel", index_name # Create arguments data = list(args) for i, (temp, dtype) in enumerate(assign_dtypes(impero_c.temporaries, scalar_type)): - name = "t%d" % i + name = temp.name if hasattr(temp, "name") and temp.shape else "t%d" % i if isinstance(temp, gem.Constant): data.append(lp.TemporaryVariable(name, shape=temp.shape, dtype=dtype, initializer=temp.array, address_space=lp.AddressSpace.LOCAL, read_only=True)) else: @@ -240,13 +242,14 @@ def generate(impero_c, args, scalar_type, kernel_name="loopy_kernel", index_name # Create loopy kernel knl = lp.make_function(domains, instructions, data, name=kernel_name, target=target, - seq_dependencies=True, silenced_warnings=["summing_if_branches_ops"], + seq_dependencies=True, silenced_warnings=["summing_if_branches_ops", "single_writer_after_creation", + "unused_inames", "insn_count_subgroups_upper_bound"], lang_version=(2018, 2), preambles=preamble) # Prevent loopy interchange by loopy knl = lp.prioritize_loops(knl, ",".join(ctx.index_extent.keys())) - return knl, event_name + return (knl, ctx, event_name) if return_ctx else (knl, event_name) def create_domains(indices): @@ -258,7 +261,8 @@ def create_domains(indices): domains = [] for idx, extent in indices: inames = isl.make_zero_and_vars([idx]) - domains.append(((inames[0].le_set(inames[idx])) & (inames[idx].lt_set(inames[0] + extent)))) + domains.append(isl.BasicSet(str(((inames[0].le_set(inames[idx])) & + (inames[idx].lt_set(inames[0] + extent)))))) if not domains: domains = [isl.BasicSet("[] -> {[]}")] @@ -350,6 +354,22 @@ def statement_evaluate(leaf, ctx): return [lp.CallInstruction(lhs, rhs, within_inames=ctx.active_inames())] elif isinstance(expr, gem.Solve): + name = "mtf_solve" if getattr(expr.ctx, "matfree") else "solve" + idx = ctx.pymbolic_multiindex(expr.shape) + var = ctx.pymbolic_variable(expr) + lhs = (SubArrayRef(idx, p.Subscript(var, idx)),) + + reads = [] + prec = getattr(expr.ctx, "preconditioner") + childs = expr.children+(prec,) if prec else expr.children + for child in childs: + idx_reads = ctx.pymbolic_multiindex(child.shape) + var_reads = ctx.pymbolic_variable(child) + reads.append(SubArrayRef(idx_reads, p.Subscript(var_reads, idx_reads))) + rhs = p.Call(p.Variable(name), tuple(reads)) + + return [lp.CallInstruction(lhs, rhs, within_inames=ctx.active_inames())] + elif isinstance(expr, gem.Action): idx = ctx.pymbolic_multiindex(expr.shape) var = ctx.pymbolic_variable(expr) lhs = (SubArrayRef(idx, p.Subscript(var, idx)),) @@ -359,7 +379,7 @@ def statement_evaluate(leaf, ctx): idx_reads = ctx.pymbolic_multiindex(child.shape) var_reads = ctx.pymbolic_variable(child) reads.append(SubArrayRef(idx_reads, p.Subscript(var_reads, idx_reads))) - rhs = p.Call(p.Variable("solve"), tuple(reads)) + rhs = p.Call(p.Variable("action"), tuple(reads)) return [lp.CallInstruction(lhs, rhs, within_inames=ctx.active_inames())] else: