diff --git a/.github/workflows/pytest-petsc.yml b/.github/workflows/pytest-petsc.yml new file mode 100644 index 0000000000..91483e2fc4 --- /dev/null +++ b/.github/workflows/pytest-petsc.yml @@ -0,0 +1,74 @@ +name: CI-petsc + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + # Trigger the workflow on push or pull request, + # but only for the master branch + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + pytest: + name: ${{ matrix.name }}-${{ matrix.set }} + runs-on: "${{ matrix.os }}" + + env: + DOCKER_BUILDKIT: "1" + DEVITO_ARCH: "${{ matrix.arch }}" + DEVITO_LANGUAGE: ${{ matrix.language }} + + strategy: + # Prevent all build to stop if a single one fails + fail-fast: false + + matrix: + name: [ + pytest-docker-py39-gcc-noomp + ] + include: + - name: pytest-docker-py39-gcc-noomp + python-version: '3.9' + os: ubuntu-latest + arch: "gcc" + language: "C" + sympy: "1.12" + + steps: + - name: Checkout devito + uses: actions/checkout@v4 + + - name: Build docker image + run: | + docker build . --file docker/Dockerfile.devito --tag devito_img --build-arg base=zoeleibowitz/bases:cpu-${{ matrix.arch }} --build-arg petscinstall=petsc + + - name: Set run prefix + run: | + echo "RUN_CMD=docker run --rm -t -e CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }} --name testrun devito_img" >> $GITHUB_ENV + id: set-run + + - name: Set tests + run : | + echo "TESTS=tests/test_petsc.py" >> $GITHUB_ENV + id: set-tests + + - name: Check configuration + run: | + ${{ env.RUN_CMD }} python3 -c "from devito import configuration; print(''.join(['%s: %s \n' % (k, v) for (k, v) in configuration.items()]))" + + - name: Test with pytest + run: | + ${{ env.RUN_CMD }} mpiexec -n 1 pytest --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }} + + - name: Upload coverage to Codecov + if: "!contains(matrix.name, 'docker')" + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + name: ${{ matrix.name }} diff --git a/conftest.py b/conftest.py index 4bb0629327..b804aaebc7 100644 --- a/conftest.py +++ b/conftest.py @@ -14,6 +14,7 @@ from devito.ir.iet import (FindNodes, FindSymbols, Iteration, ParallelBlock, retrieve_iteration_tree) from devito.tools import as_tuple +from devito.petsc.utils import get_petsc_dir, get_petsc_arch try: from mpi4py import MPI # noqa @@ -33,7 +34,7 @@ def skipif(items, whole_module=False): accepted = set() accepted.update({'device', 'device-C', 'device-openmp', 'device-openacc', 'device-aomp', 'cpu64-icc', 'cpu64-icx', 'cpu64-nvc', 'cpu64-arm', - 'cpu64-icpx', 'chkpnt'}) + 'cpu64-icpx', 'chkpnt', 'petsc'}) accepted.update({'nodevice'}) unknown = sorted(set(items) - accepted) if unknown: @@ -87,6 +88,19 @@ def skipif(items, whole_module=False): if i == 'chkpnt' and Revolver is NoopRevolver: skipit = "pyrevolve not installed" break + if i == 'petsc': + petsc_dir = get_petsc_dir() + petsc_arch = get_petsc_arch() + if petsc_dir is None or petsc_arch is None: + skipit = "PETSC_DIR or PETSC_ARCH are not set" + break + else: + petsc_installed = os.path.join( + petsc_dir, petsc_arch, 'include', 'petscconf.h' + ) + if not os.path.isfile(petsc_installed): + skipit = "PETSc is not installed" + break if skipit is False: return pytest.mark.skipif(False, reason='') diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index 946eb4324b..f699b353f0 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -170,7 +170,8 @@ def concretize_subdims(exprs, **kwargs): """ sregistry = kwargs.get('sregistry') - mapper = {} + # Update based on changes in #2509 + mapper = kwargs.get('concretize_mapper', {}) rebuilt = {} # Rebuilt implicit dims etc which are shared between dimensions _concretize_subdims(exprs, mapper, rebuilt, sregistry) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index ada1c23f22..6b2261b51c 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -9,10 +9,12 @@ Stencil, detect_io, detect_accesses) from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict -from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min +from devito.types import (Eq, Inc, ReduceMax, ReduceMin, + relational_min) +from devito.types.equation import PetscEq __all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax', - 'identity_mapper'] + 'identity_mapper', 'OpPetsc'] class IREq(sympy.Eq, Pickable): @@ -102,7 +104,8 @@ def detect(cls, expr): reduction_mapper = { Inc: OpInc, ReduceMax: OpMax, - ReduceMin: OpMin + ReduceMin: OpMin, + PetscEq: OpPetsc } try: return reduction_mapper[type(expr)] @@ -119,6 +122,7 @@ def detect(cls, expr): OpInc = Operation('+') OpMax = Operation('max') OpMin = Operation('min') +OpPetsc = Operation('solve') identity_mapper = { diff --git a/devito/ir/iet/algorithms.py b/devito/ir/iet/algorithms.py index 0b57b876f7..52f48e28b1 100644 --- a/devito/ir/iet/algorithms.py +++ b/devito/ir/iet/algorithms.py @@ -3,6 +3,8 @@ from devito.ir.iet import (Expression, Increment, Iteration, List, Conditional, SyncSpot, Section, HaloSpot, ExpressionBundle) from devito.tools import timed_pass +from devito.petsc.types import MetaData +from devito.petsc.iet.utils import petsc_iet_mapper __all__ = ['iet_build'] @@ -24,6 +26,8 @@ def iet_build(stree): for e in i.exprs: if e.is_Increment: exprs.append(Increment(e)) + elif isinstance(e.rhs, MetaData): + exprs.append(petsc_iet_mapper[e.operation](e, operation=e.operation)) else: exprs.append(Expression(e, operation=e.operation)) body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs) diff --git a/devito/ir/iet/efunc.py b/devito/ir/iet/efunc.py index 10aa8920e6..064f02c487 100644 --- a/devito/ir/iet/efunc.py +++ b/devito/ir/iet/efunc.py @@ -1,6 +1,6 @@ from functools import cached_property -from devito.ir.iet.nodes import Call, Callable +from devito.ir.iet.nodes import Call, Callable, FixedArgsCallable from devito.ir.iet.utils import derive_parameters from devito.symbolics import uxreplace from devito.tools import as_tuple @@ -131,7 +131,7 @@ class AsyncCall(Call): pass -class ThreadCallable(Callable): +class ThreadCallable(FixedArgsCallable): """ A Callable executed asynchronously by a thread. diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index d45c9af939..dd6db491c0 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1,6 +1,7 @@ """The Iteration/Expression Tree (IET) hierarchy.""" import abc +import ctypes import inspect from functools import cached_property from collections import OrderedDict, namedtuple @@ -19,7 +20,7 @@ ctypes_to_cstr) from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed, Symbol) -from devito.types.object import AbstractObject, LocalObject +from devito.types.object import AbstractObject, LocalObject, LocalCompositeObject __all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable', 'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section', @@ -28,7 +29,7 @@ 'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace', - 'CallableBody', 'Transfer'] + 'CallableBody', 'Transfer', 'Callback', 'FixedArgsCallable'] # First-class IET nodes @@ -763,6 +764,15 @@ def defines(self): return self.all_parameters +class FixedArgsCallable(Callable): + + """ + A Callable class that enforces a fixed function signature. + """ + + pass + + class CallableBody(MultiTraversable): """ @@ -1041,8 +1051,8 @@ class Dereference(ExprStmt, Node): The following cases are supported: * `pointer` is a PointerArray or TempFunction, and `pointee` is an Array. - * `pointer` is an ArrayObject representing a pointer to a C struct, and - `pointee` is a field in `pointer`. + * `pointer` is an ArrayObject or LocalCompositeObject representing a pointer + to a C struct, and `pointee` is a field in `pointer`. """ is_Dereference = True @@ -1061,13 +1071,15 @@ def functions(self): @property def expr_symbols(self): - ret = [self.pointer.indexed] + ret = [] if self.pointer.is_PointerArray or self.pointer.is_TempFunction: - ret.append(self.pointee.indexed) + ret.extend([self.pointer.indexed, self.pointee.indexed]) ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) ret.extend(self.pointer.free_symbols) else: - ret.append(self.pointee._C_symbol) + assert (isinstance(self.pointer, LocalCompositeObject) or + issubclass(self.pointer._C_ctype, ctypes._Pointer)) + ret.extend([self.pointer._C_symbol, self.pointee._C_symbol]) return tuple(filter_ordered(ret)) @property @@ -1133,6 +1145,45 @@ def defines(self): return tuple(self.parameters) +class Callback(Call): + """ + Base class for special callback types. + + Parameters + ---------- + name : str + The name of the callback. + retval : str + The return type of the callback. + param_types : str or list of str + The return type for each argument of the callback. + + Notes + ----- + - The reason Callback is an IET type rather than a SymPy type is + due to the fact that, when represented at the SymPy level, the IET + engine fails to bind the callback to a specific Call. Consequently, + errors occur during the creation of the call graph. + """ + # TODO: Create a common base class for Call and Callback to avoid + # having arguments=None here + def __init__(self, name, retval=None, param_types=None, arguments=None): + super().__init__(name=name) + self.retval = retval + self.param_types = as_tuple(param_types) + + @property + def callback_form(self): + """ + A string representation of the callback form. + + Notes + ----- + To be overridden by subclasses. + """ + return + + class Section(List): """ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 9749825407..b6548ec57f 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -24,7 +24,7 @@ c_restrict_void_p, sorted_priority) from devito.types.basic import AbstractFunction, Basic from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer, - IndexedData, DeviceMap) + IndexedData, DeviceMap, LocalCompositeObject) __all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols', @@ -186,11 +186,10 @@ def __init__(self, *args, compiler=None, **kwargs): '_mem_constant': 'static', '_mem_shared': '', } - _restrict_keyword = 'restrict' def _gen_struct_decl(self, obj, masked=()): """ - Convert ctypes.Struct -> cgen.Structure. + Convert ctypes.Struct and LocalCompositeObject -> cgen.Structure. """ ctype = obj._C_ctype try: @@ -201,7 +200,16 @@ def _gen_struct_decl(self, obj, masked=()): return None except TypeError: # E.g., `ctype` is of type `dtypes_lowering.CustomDtype` - return None + if isinstance(obj, LocalCompositeObject): + # TODO: Potentially re-evaluate: Setting ctype to obj allows + # _gen_struct_decl to generate a cgen.Structure from a + # LocalCompositeObject, where obj._C_ctype is a CustomDtype. + # LocalCompositeObject has a __fields__ property, + # which allows the subsequent code in this function to function + # correctly. + ctype = obj + else: + return None try: return obj._C_typedecl @@ -250,10 +258,10 @@ def _gen_value(self, obj, mode=1, masked=()): strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: - strtype = '%s%s' % (strtype, self._restrict_keyword) + strtype = '%s%s' % (strtype, obj._restrict_keyword) strtype = ' '.join(qualifiers + [strtype]) - if obj.is_LocalObject and obj._C_modifier is not None and mode == 2: + if obj.is_LocalType and obj._C_modifier is not None and mode == 2: strtype += obj._C_modifier strname = obj._C_name @@ -624,6 +632,9 @@ def visit_Lambda(self, o): (', '.join(captures), ', '.join(decls), ''.join(extra))) return LambdaCollection([top, c.Block(body)]) + def visit_Callback(self, o, nested_call=False): + return CallbackArg(o) + def visit_HaloSpot(self, o): body = flatten(self._visit(i) for i in o.children) return c.Collection(body) @@ -675,8 +686,11 @@ def _operator_typedecls(self, o, mode='all'): for i in o._func_table.values(): if not i.local: continue - typedecls.extend([self._gen_struct_decl(j) for j in i.root.parameters - if xfilter(j)]) + typedecls.extend([ + self._gen_struct_decl(j) + for j in FindSymbols().visit(i.root) + if xfilter(j) + ]) typedecls = filter_sorted(typedecls, key=lambda i: i.tpname) return typedecls @@ -1427,3 +1441,12 @@ def sorted_efuncs(efuncs): CommCallable: 1 } return sorted_priority(efuncs, priority) + + +class CallbackArg(c.Generable): + + def __init__(self, callback): + self.callback = callback + + def generate(self): + yield self.callback.callback_form diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 69a089e12b..3e6f17cf94 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -34,7 +34,8 @@ from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) from devito.types.dimension import Thickness - +from devito.petsc.iet.passes import lower_petsc +from devito.petsc.clusters import petsc_preprocess __all__ = ['Operator'] @@ -262,6 +263,8 @@ def _lower(cls, expressions, **kwargs): """ # Create a symbol registry kwargs.setdefault('sregistry', SymbolRegistry()) + # TODO: To be updated based on changes in #2509 + kwargs.setdefault('concretize_mapper', {}) expressions = as_tuple(expressions) @@ -380,6 +383,9 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs): # Build a sequence of Clusters from a sequence of Eqs clusters = clusterize(expressions, **kwargs) + # Preprocess clusters for PETSc lowering + clusters = petsc_preprocess(clusters) + # Operation count before specialization init_ops = sum(estimate_cost(c.exprs) for c in clusters if c.is_dense) @@ -477,6 +483,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) + + lower_petsc(graph, **kwargs) + graph = cls._specialize_iet(graph, **kwargs) # Instrument the IET for C-level profiling @@ -508,7 +517,7 @@ def dimensions(self): # During compilation other Dimensions may have been produced dimensions = FindSymbols('dimensions').visit(self) - ret.update(d for d in dimensions if d.is_PerfKnob) + ret.update(dimensions) ret = tuple(sorted(ret, key=attrgetter('name'))) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 3532169754..ef5878dacb 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -19,7 +19,7 @@ SizeOf, VOID, Keyword, pow_to_mul) from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap, - DeviceRM, Eq, Symbol) + DeviceRM, Eq, Symbol, IndexedData) __all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage'] @@ -256,9 +256,10 @@ def _alloc_object_array_on_low_lat_mem(self, site, obj, storage): """ Allocate an Array of Objects in the low latency memory. """ + frees = getattr(obj, '_C_free', None) decl = Definition(obj) - storage.update(obj, site, allocs=decl) + storage.update(obj, site, allocs=decl, frees=frees) def _alloc_pointed_array_on_high_bw_mem(self, site, obj, storage): """ @@ -333,6 +334,10 @@ def _inject_definitions(self, iet, storage): init = self.lang['thread-num'](retobj=tid) frees.append(Block(header=header, body=[init] + body)) frees.extend(as_list(cbody.frees) + flatten(v.frees)) + frees = sorted(frees, key=lambda x: min( + (obj._C_free_priority for obj in FindSymbols().visit(x) + if obj.is_LocalType), default=float('inf') + )) # maps/unmaps maps = as_list(cbody.maps) + flatten(v.maps) @@ -407,11 +412,10 @@ def place_definitions(self, iet, globs=None, **kwargs): # Track, to be handled by the EntryFunction being a global obj! globs.add(i) - elif i.is_ObjectArray: - self._alloc_object_array_on_low_lat_mem(iet, i, storage) - elif i.is_PointerArray: self._alloc_pointed_array_on_high_bw_mem(iet, i, storage) + else: + self._alloc_object_array_on_low_lat_mem(iet, i, storage) # Handle postponed global objects includes = set() @@ -447,7 +451,8 @@ def place_casts(self, iet, **kwargs): # (i) Dereferencing a PointerArray, e.g., `float (*r0)[.] = (float(*)[.]) pr0[.]` # (ii) Declaring a raw pointer, e.g., `float * r0 = NULL; *malloc(&(r0), ...) defines = set(FindSymbols('defines|globals').visit(iet)) - bases = sorted({i.base for i in indexeds}, key=lambda i: i.name) + bases = sorted({i.base for i in indexeds + if isinstance(i.base, IndexedData)}, key=lambda i: i.name) # Some objects don't distinguish their _C_symbol because they are known, # by construction, not to require it, thus making the generated code diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index b9b5bf15d4..a88cf58d45 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -3,7 +3,7 @@ from devito.ir.iet import (Call, ExprStmt, Iteration, SyncSpot, AsyncCallable, FindNodes, FindSymbols, MapNodes, MetaCall, Transformer, - EntryFunction, ThreadCallable, Uxreplace, + EntryFunction, FixedArgsCallable, Uxreplace, derive_parameters) from devito.ir.support import SymbolRegistry from devito.mpi.distributed import MPINeighborhood @@ -129,6 +129,7 @@ def apply(self, func, **kwargs): compiler.add_libraries(as_tuple(metadata.get('libs'))) compiler.add_library_dirs(as_tuple(metadata.get('lib_dirs')), rpath=metadata.get('rpath', False)) + compiler.add_ldflags(as_tuple(metadata.get('ldflags'))) except KeyError: pass @@ -602,12 +603,12 @@ def update_args(root, efuncs, dag): foo(..., z) : root(x, z) """ - if isinstance(root, ThreadCallable): + if isinstance(root, FixedArgsCallable): return efuncs # The parameters/arguments lists may have changed since a pass may have: # 1) introduced a new symbol - new_params = derive_parameters(root) + new_params = derive_parameters(root, drop_locals=True) # 2) defined a symbol for which no definition was available yet (e.g. # via a malloc, or a Dereference) diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 930cc108b2..180d2231a6 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -72,7 +72,9 @@ def _hoist_invariant(iet): """ # Precompute scopes to save time - scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()} + scopes = {} + for i, v in MapNodes(child_types=Expression).visit(iet).items(): + scopes[i] = Scope([e.expr for e in v]) # Analysis hsmapper = {} diff --git a/devito/petsc/__init__.py b/devito/petsc/__init__.py new file mode 100644 index 0000000000..2927bc960e --- /dev/null +++ b/devito/petsc/__init__.py @@ -0,0 +1 @@ +from devito.petsc.solve import * # noqa diff --git a/devito/petsc/clusters.py b/devito/petsc/clusters.py new file mode 100644 index 0000000000..e035ccbefc --- /dev/null +++ b/devito/petsc/clusters.py @@ -0,0 +1,27 @@ +from devito.tools import timed_pass +from devito.petsc.types import LinearSolveExpr + + +@timed_pass() +def petsc_preprocess(clusters): + """ + Preprocess the clusters to make them suitable for PETSc + code generation. + """ + clusters = petsc_lift(clusters) + return clusters + + +def petsc_lift(clusters): + """ + Lift the iteration space surrounding each PETSc solve to create + distinct iteration loops. + """ + processed = [] + for c in clusters: + if isinstance(c.exprs[0].rhs, LinearSolveExpr): + ispace = c.ispace.lift(c.exprs[0].rhs.fielddata.space_dimensions) + processed.append(c.rebuild(ispace=ispace)) + else: + processed.append(c) + return processed diff --git a/devito/petsc/iet/__init__.py b/devito/petsc/iet/__init__.py new file mode 100644 index 0000000000..beb6d1f2d1 --- /dev/null +++ b/devito/petsc/iet/__init__.py @@ -0,0 +1 @@ +from devito.petsc.iet import * # noqa diff --git a/devito/petsc/iet/nodes.py b/devito/petsc/iet/nodes.py new file mode 100644 index 0000000000..abb5da3acd --- /dev/null +++ b/devito/petsc/iet/nodes.py @@ -0,0 +1,31 @@ +from devito.ir.iet import Expression, Callback, FixedArgsCallable, Call +from devito.ir.equations import OpPetsc + + +class PetscMetaData(Expression): + """ + Base class for general expressions required to run a PETSc solver. + """ + def __init__(self, expr, pragmas=None, operation=OpPetsc): + super().__init__(expr, pragmas=pragmas, operation=operation) + + +class PETScCallable(FixedArgsCallable): + pass + + +class MatShellSetOp(Callback): + @property + def callback_form(self): + param_types_str = ', '.join([str(t) for t in self.param_types]) + return "(%s (*)(%s))%s" % (self.retval, param_types_str, self.name) + + +class FormFunctionCallback(Callback): + @property + def callback_form(self): + return "%s" % self.name + + +class PETScCall(Call): + pass diff --git a/devito/petsc/iet/passes.py b/devito/petsc/iet/passes.py new file mode 100644 index 0000000000..710f25a611 --- /dev/null +++ b/devito/petsc/iet/passes.py @@ -0,0 +1,248 @@ +import cgen as c +import numpy as np +from functools import cached_property + +from devito.passes.iet.engine import iet_pass +from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine, + DummyExpr, CallableBody, List, Call, Callable, + FindNodes) +from devito.symbolics import Byref, Macro, FieldFromPointer +from devito.types import Symbol, Scalar +from devito.types.basic import DataSymbol +from devito.tools import frozendict +from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData, + PointerIS, Mat, LocalVec, GlobalVec, CallbackMat, SNES, + DummyArg, PetscInt, PointerDM, PointerMat, MatReuse, + CallbackPointerIS, CallbackPointerDM, JacobianStruct, + SubMatrixStruct, Initialize, Finalize, ArgvSymbol) +from devito.petsc.types.macros import petsc_func_begin_user +from devito.petsc.iet.nodes import PetscMetaData +from devito.petsc.utils import core_metadata +from devito.petsc.iet.routines import (CBBuilder, CCBBuilder, BaseObjectBuilder, + CoupledObjectBuilder, BaseSetup, CoupledSetup, + Solver, CoupledSolver, TimeDependent, + NonTimeDependent) +from devito.petsc.iet.utils import petsc_call, petsc_call_mpi + + +@iet_pass +def lower_petsc(iet, **kwargs): + # Check if PETScSolve was used + injectsolve_mapper = MapNodes(Iteration, PetscMetaData, + 'groupby').visit(iet) + + if not injectsolve_mapper: + return iet, {} + + metadata = core_metadata() + data = FindNodes(PetscMetaData).visit(iet) + + if any(filter(lambda i: isinstance(i.expr.rhs, Initialize), data)): + return initialize(iet), metadata + + if any(filter(lambda i: isinstance(i.expr.rhs, Finalize), data)): + return finalize(iet), metadata + + unique_grids = {i.expr.rhs.grid for (i,) in injectsolve_mapper.values()} + # Assumption is that all solves are on the same grid + if len(unique_grids) > 1: + raise ValueError("All PETScSolves must use the same Grid, but multiple found.") + + # Create core PETSc calls (not specific to each PETScSolve) + core = make_core_petsc_calls(objs, **kwargs) + + setup = [] + subs = {} + efuncs = {} + + for iters, (injectsolve,) in injectsolve_mapper.items(): + + builder = Builder(injectsolve, objs, iters, **kwargs) + + setup.extend(builder.solversetup.calls) + + # Transform the spatial iteration loop with the calls to execute the solver + subs.update({builder.solve.spatial_body: builder.solve.calls}) + + efuncs.update(builder.cbbuilder.efuncs) + + populate_matrix_context(efuncs, objs) + + iet = Transformer(subs).visit(iet) + + body = core + tuple(setup) + (BlankLine,) + iet.body.body + body = iet.body._rebuild(body=body) + iet = iet._rebuild(body=body) + metadata.update({'efuncs': tuple(efuncs.values())}) + return iet, metadata + + +def initialize(iet): + # should be int because the correct type for argc is a C int + # and not a int32 + argc = DataSymbol(name='argc', dtype=np.int32) + argv = ArgvSymbol(name='argv') + Help = Macro('help') + + help_string = c.Line(r'static char help[] = "This is help text.\n";') + + init_body = petsc_call('PetscInitialize', [Byref(argc), Byref(argv), Null, Help]) + init_body = CallableBody( + body=(petsc_func_begin_user, help_string, init_body), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + return iet._rebuild(body=init_body) + + +def finalize(iet): + finalize_body = petsc_call('PetscFinalize', []) + finalize_body = CallableBody( + body=(petsc_func_begin_user, finalize_body), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + return iet._rebuild(body=finalize_body) + + +def make_core_petsc_calls(objs, **kwargs): + call_mpi = petsc_call_mpi('MPI_Comm_size', [objs['comm'], Byref(objs['size'])]) + + return call_mpi, BlankLine + + +class Builder: + """ + This class is designed to support future extensions, enabling + different combinations of solver types, preconditioning methods, + and other functionalities as needed. + The class will be extended to accommodate different solver types by + returning subclasses of the objects initialised in __init__, + depending on the properties of `injectsolve`. + """ + def __init__(self, injectsolve, objs, iters, **kwargs): + self.injectsolve = injectsolve + self.objs = objs + self.iters = iters + self.kwargs = kwargs + self.coupled = isinstance(injectsolve.expr.rhs.fielddata, MultipleFieldData) + self.args = { + 'injectsolve': self.injectsolve, + 'objs': self.objs, + 'iters': self.iters, + **self.kwargs + } + self.args['solver_objs'] = self.objbuilder.solver_objs + self.args['timedep'] = self.timedep + self.args['cbbuilder'] = self.cbbuilder + + @cached_property + def objbuilder(self): + return ( + CoupledObjectBuilder(**self.args) + if self.coupled else + BaseObjectBuilder(**self.args) + ) + + @cached_property + def timedep(self): + time_mapper = self.injectsolve.expr.rhs.time_mapper + timedep_class = TimeDependent if time_mapper else NonTimeDependent + return timedep_class(**self.args) + + @cached_property + def cbbuilder(self): + return CCBBuilder(**self.args) if self.coupled else CBBuilder(**self.args) + + @cached_property + def solversetup(self): + return CoupledSetup(**self.args) if self.coupled else BaseSetup(**self.args) + + @cached_property + def solve(self): + return CoupledSolver(**self.args) if self.coupled else Solver(**self.args) + + +def populate_matrix_context(efuncs, objs): + if not objs['dummyefunc'] in efuncs.values(): + return + + subdms_expr = DummyExpr( + FieldFromPointer(objs['Subdms']._C_symbol, objs['ljacctx']), + objs['Subdms']._C_symbol + ) + fields_expr = DummyExpr( + FieldFromPointer(objs['Fields']._C_symbol, objs['ljacctx']), + objs['Fields']._C_symbol + ) + body = CallableBody( + List(body=[subdms_expr, fields_expr]), + init=(objs['begin_user'],), + retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])]) + ) + name = 'PopulateMatContext' + efuncs[name] = Callable( + name, body, objs['err'], + parameters=[objs['ljacctx'], objs['Subdms'], objs['Fields']] + ) + + +# TODO: Devito MPI + PETSc testing +# if kwargs['options']['mpi'] -> communicator = grid.distributor._obj_comm +communicator = 'PETSC_COMM_WORLD' +subdms = PointerDM(name='subdms') +fields = PointerIS(name='fields') +submats = PointerMat(name='submats') +rows = PointerIS(name='rows') +cols = PointerIS(name='cols') + + +# A static dict containing shared symbols and objects that are not +# unique to each PETScSolve. +# Many of these objects are used as arguments in callback functions to make +# the C code cleaner and more modular. This is also a step toward leveraging +# Devito's `reuse_efuncs` functionality, allowing reuse of efuncs when +# they are semantically identical. +objs = frozendict({ + 'size': PetscMPIInt(name='size'), + 'comm': communicator, + 'err': PetscErrorCode(name='err'), + 'block': CallbackMat('block'), + 'submat_arr': PointerMat(name='submat_arr'), + 'subblockrows': PetscInt('subblockrows'), + 'subblockcols': PetscInt('subblockcols'), + 'rowidx': PetscInt('rowidx'), + 'colidx': PetscInt('colidx'), + 'J': Mat('J'), + 'X': GlobalVec('X'), + 'xloc': LocalVec('xloc'), + 'Y': GlobalVec('Y'), + 'yloc': LocalVec('yloc'), + 'F': GlobalVec('F'), + 'floc': LocalVec('floc'), + 'B': GlobalVec('B'), + 'nfields': PetscInt('nfields'), + 'irow': PointerIS(name='irow'), + 'icol': PointerIS(name='icol'), + 'nsubmats': Scalar('nsubmats', dtype=np.int32), + 'matreuse': MatReuse('scall'), + 'snes': SNES('snes'), + 'rows': rows, + 'cols': cols, + 'Subdms': subdms, + 'LocalSubdms': CallbackPointerDM(name='subdms'), + 'Fields': fields, + 'LocalFields': CallbackPointerIS(name='fields'), + 'Submats': submats, + 'ljacctx': JacobianStruct( + fields=[subdms, fields, submats], modifier=' *' + ), + 'subctx': SubMatrixStruct(fields=[rows, cols]), + 'Null': Macro('NULL'), + 'dummyctx': Symbol('lctx'), + 'dummyptr': DummyArg('dummy'), + 'dummyefunc': Symbol('dummyefunc'), + 'dof': PetscInt('dof'), + 'begin_user': c.Line('PetscFunctionBeginUser;'), +}) + +# Move to macros file? +Null = Macro('NULL') diff --git a/devito/petsc/iet/routines.py b/devito/petsc/iet/routines.py new file mode 100644 index 0000000000..a70e987e71 --- /dev/null +++ b/devito/petsc/iet/routines.py @@ -0,0 +1,1573 @@ +from collections import OrderedDict +from functools import cached_property + +from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody, + Dereference, DummyExpr, BlankLine, Callable, FindNodes, + retrieve_iteration_tree, filter_iterations, Iteration) +from devito.symbolics import (Byref, FieldFromPointer, cast_mapper, VOIDP, + FieldFromComposite, IntDiv, Deref, Mod) +from devito.symbolics.unevaluation import Mul +from devito.types.basic import AbstractFunction +from devito.types import Temp, Dimension +from devito.tools import filter_ordered + +from devito.petsc.types import PETScArray +from devito.petsc.iet.nodes import (PETScCallable, FormFunctionCallback, + MatShellSetOp, PetscMetaData) +from devito.petsc.iet.utils import petsc_call, petsc_struct +from devito.petsc.utils import solver_mapper +from devito.petsc.types import (DM, Mat, LocalVec, GlobalVec, KSP, PC, SNES, + PetscInt, StartPtr, PointerIS, PointerDM, VecScatter, + DMCast, JacobianStructCast, JacobianStruct, + SubMatrixStruct, CallbackDM) + + +class CBBuilder: + """ + Build IET routines to generate PETSc callback functions. + """ + def __init__(self, **kwargs): + + self.rcompile = kwargs.get('rcompile', None) + self.sregistry = kwargs.get('sregistry', None) + self.concretize_mapper = kwargs.get('concretize_mapper', {}) + self.timedep = kwargs.get('timedep') + self.objs = kwargs.get('objs') + self.solver_objs = kwargs.get('solver_objs') + self.injectsolve = kwargs.get('injectsolve') + + self._efuncs = OrderedDict() + self._struct_params = [] + + self._main_matvec_callback = None + self._main_formfunc_callback = None + self._user_struct_callback = None + # TODO: Test pickling. The mutability of these lists + # could cause issues when pickling? + self._matvecs = [] + self._formfuncs = [] + self._formrhs = [] + + self._make_core() + self._efuncs = self._uxreplace_efuncs() + + @property + def efuncs(self): + return self._efuncs + + @property + def struct_params(self): + return self._struct_params + + @property + def filtered_struct_params(self): + return filter_ordered(self.struct_params) + + @property + def main_matvec_callback(self): + """ + This is the matvec callback associated with the whole Jacobian i.e + is set in the main kernel via + `PetscCall(MatShellSetOperation(J,MATOP_MULT,(void (*)(void))...));` + """ + return self._matvecs[0] + + @property + def main_formfunc_callback(self): + return self._formfuncs[0] + + @property + def matvecs(self): + return self._matvecs + + @property + def formfuncs(self): + return self._formfuncs + + @property + def formrhs(self): + return self._formrhs + + @property + def user_struct_callback(self): + return self._user_struct_callback + + def _make_core(self): + fielddata = self.injectsolve.expr.rhs.fielddata + self._make_matvec(fielddata, fielddata.matvecs) + self._make_formfunc(fielddata) + self._make_formrhs(fielddata) + self._make_user_struct_callback() + + def _make_matvec(self, fielddata, matvecs, prefix='MatMult'): + # Compile matvec `eqns` into an IET via recursive compilation + irs_matvec, _ = self.rcompile(matvecs, + options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper) + body_matvec = self._create_matvec_body(List(body=irs_matvec.uiet.body), + fielddata) + + objs = self.objs + cb = PETScCallable( + self.sregistry.make_name(prefix=prefix), + body_matvec, + retval=objs['err'], + parameters=(objs['J'], objs['X'], objs['Y']) + ) + self._matvecs.append(cb) + self._efuncs[cb.name] = cb + + def _create_matvec_body(self, body, fielddata): + linsolve_expr = self.injectsolve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + xlocal = objs['xloc'] + ylocal = objs['yloc'] + y_matvec = fielddata.arrays['y'] + x_matvec = fielddata.arrays['x'] + + body = self.timedep.uxreplace_time(body) + + fields = self._dummy_fields(body) + + mat_get_dm = petsc_call('MatGetDM', [objs['J'], Byref(dmda)]) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(xlocal)] + ) + + global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, objs['X'], + insert_vals, xlocal] + ) + + global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, objs['X'], insert_vals, xlocal + ]) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(ylocal)] + ) + + vec_get_array_y = petsc_call( + 'VecGetArray', [ylocal, Byref(y_matvec._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [xlocal, Byref(x_matvec._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + vec_restore_array_y = petsc_call( + 'VecRestoreArray', [ylocal, Byref(y_matvec._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [xlocal, Byref(x_matvec._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, ylocal, insert_vals, objs['Y'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, ylocal, insert_vals, objs['Y'] + ]) + + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(xlocal)] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(ylocal)] + ) + + # TODO: Some of the calls are placed in the `stacks` argument of the + # `CallableBody` to ensure that they precede the `cast` statements. The + # 'casts' depend on the calls, so this order is necessary. By doing this, + # you avoid having to manually construct the `casts` and can allow + # Devito to handle their construction. This is a temporary solution and + # should be revisited + + body = body._rebuild( + body=body.body + + (vec_restore_array_y, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) + ) + + stacks = ( + mat_get_dm, + dm_get_app_context, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + vec_get_array_y, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + dereference_funcs = [Dereference(i, ctx) for i in + fields if isinstance(i.function, AbstractFunction)] + + matvec_body = CallableBody( + List(body=body), + init=(objs['begin_user'],), + stacks=stacks+tuple(dereference_funcs), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} + matvec_body = Uxreplace(subs).visit(matvec_body) + + self._struct_params.extend(fields) + return matvec_body + + def _make_formfunc(self, fielddata): + formfuncs = fielddata.formfuncs + # Compile formfunc `eqns` into an IET via recursive compilation + irs_formfunc, _ = self.rcompile( + formfuncs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body_formfunc = self._create_formfunc_body( + List(body=irs_formfunc.uiet.body), fielddata + ) + objs = self.objs + cb = PETScCallable( + self.sregistry.make_name(prefix='FormFunction'), + body_formfunc, + retval=objs['err'], + parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr']) + ) + self._formfuncs.append(cb) + self._efuncs[cb.name] = cb + + def _create_formfunc_body(self, body, fielddata): + linsolve_expr = self.injectsolve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + + body = self.timedep.uxreplace_time(body) + + fields = self._dummy_fields(body) + self._struct_params.extend(fields) + + f_formfunc = fielddata.arrays['f'] + x_formfunc = fielddata.arrays['x'] + + dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(objs['xloc'])] + ) + + global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, objs['X'], + insert_vals, objs['xloc']] + ) + + global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, objs['X'], insert_vals, objs['xloc'] + ]) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(objs['floc'])] + ) + + vec_get_array_y = petsc_call( + 'VecGetArray', [objs['floc'], Byref(f_formfunc._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [objs['xloc'], Byref(x_formfunc._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + vec_restore_array_y = petsc_call( + 'VecRestoreArray', [objs['floc'], Byref(f_formfunc._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [objs['xloc'], Byref(x_formfunc._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, objs['floc'], insert_vals, objs['F'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, objs['floc'], insert_vals, objs['F'] + ]) + + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(objs['xloc'])] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(objs['floc'])] + ) + + body = body._rebuild( + body=body.body + + (vec_restore_array_y, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) + ) + + stacks = ( + dm_cast, + dm_get_app_context, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + vec_get_array_y, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + dereference_funcs = [Dereference(i, ctx) for i in + fields if isinstance(i.function, AbstractFunction)] + + formfunc_body = CallableBody( + List(body=body), + init=(objs['begin_user'],), + stacks=stacks+tuple(dereference_funcs), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),)) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} + + return Uxreplace(subs).visit(formfunc_body) + + def _make_formrhs(self, fielddata): + formrhs = fielddata.formrhs + sobjs = self.solver_objs + + # Compile formrhs `eqns` into an IET via recursive compilation + irs_formrhs, _ = self.rcompile( + formrhs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body_formrhs = self._create_form_rhs_body( + List(body=irs_formrhs.uiet.body), fielddata + ) + objs = self.objs + cb = PETScCallable( + self.sregistry.make_name(prefix='FormRHS'), + body_formrhs, + retval=objs['err'], + parameters=(sobjs['callbackdm'], objs['B']) + ) + self._formrhs.append(cb) + self._efuncs[cb.name] = cb + + def _create_form_rhs_body(self, body, fielddata): + linsolve_expr = self.injectsolve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + + dm_get_local = petsc_call( + 'DMGetLocalVector', [dmda, Byref(sobjs['blocal'])] + ) + + dm_global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, objs['B'], + insert_vals, sobjs['blocal']] + ) + + dm_global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, objs['B'], insert_vals, + sobjs['blocal'] + ]) + + b_arr = fielddata.arrays['b'] + + vec_get_array = petsc_call( + 'VecGetArray', [sobjs['blocal'], Byref(b_arr._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + body = self.timedep.uxreplace_time(body) + + fields = self._dummy_fields(body) + self._struct_params.extend(fields) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, sobjs['blocal'], insert_vals, + objs['B'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, sobjs['blocal'], insert_vals, + objs['B'] + ]) + + vec_restore_array = petsc_call( + 'VecRestoreArray', [sobjs['blocal'], Byref(b_arr._C_symbol)] + ) + + body = body._rebuild(body=body.body + ( + dm_local_to_global_begin, dm_local_to_global_end, vec_restore_array + )) + + stacks = ( + dm_get_local, + dm_global_to_local_begin, + dm_global_to_local_end, + vec_get_array, + dm_get_app_context, + dm_get_local_info + ) + + # Dereference function data in struct + dereference_funcs = [Dereference(i, ctx) for i in + fields if isinstance(i.function, AbstractFunction)] + + formrhs_body = CallableBody( + List(body=[body]), + init=(objs['begin_user'],), + stacks=stacks+tuple(dereference_funcs), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for + i in fields if not isinstance(i.function, AbstractFunction)} + + return Uxreplace(subs).visit(formrhs_body) + + def _make_user_struct_callback(self): + """ + This is the struct initialised inside the main kernel and + attached to the DM via DMSetApplicationContext. + # TODO: this could be common between all PETScSolves instead? + """ + mainctx = self.solver_objs['userctx'] = petsc_struct( + self.sregistry.make_name(prefix='ctx'), + self.filtered_struct_params, + self.sregistry.make_name(prefix='UserCtx'), + ) + body = [ + DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol) + for i in mainctx.callback_fields + ] + struct_callback_body = CallableBody( + List(body=body), init=(self.objs['begin_user'],), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + cb = Callable( + self.sregistry.make_name(prefix='PopulateUserContext'), + struct_callback_body, self.objs['err'], + parameters=[mainctx] + ) + self._efuncs[cb.name] = cb + self._user_struct_callback = cb + + def _dummy_fields(self, iet): + # Place all context data required by the shell routines into a struct + fields = [f.function for f in FindSymbols('basics').visit(iet)] + fields = [f for f in fields if not isinstance(f.function, (PETScArray, Temp))] + fields = [ + f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo)) + ] + return fields + + def _uxreplace_efuncs(self): + sobjs = self.solver_objs + luserctx = petsc_struct( + sobjs['userctx'].name, + self.filtered_struct_params, + sobjs['userctx'].pname, + modifier=' *' + ) + mapper = {} + visitor = Uxreplace({self.objs['dummyctx']: luserctx}) + for k, v in self._efuncs.items(): + mapper.update({k: visitor.visit(v)}) + return mapper + + +class CCBBuilder(CBBuilder): + def __init__(self, **kwargs): + self._submatrices_callback = None + super().__init__(**kwargs) + + @property + def submatrices_callback(self): + return self._submatrices_callback + + @property + def submatrices(self): + return self.injectsolve.expr.rhs.fielddata.submatrices + + @property + def main_matvec_callback(self): + """ + This is the matvec callback associated with the whole Jacobian i.e + is set in the main kernel via + `PetscCall(MatShellSetOperation(J,MATOP_MULT,(void (*)(void))MyMatShellMult));` + """ + return self._main_matvec_callback + + @property + def main_formfunc_callback(self): + return self._main_formfunc_callback + + def _make_core(self): + injectsolve = self.injectsolve + targets = injectsolve.expr.rhs.fielddata.targets + all_fielddata = injectsolve.expr.rhs.fielddata + + for t in targets: + data = all_fielddata.get_field_data(t) + self._make_formfunc(data) + self._make_formrhs(data) + + row_matvecs = all_fielddata.submatrices.submatrices[t] + for submat, mtvs in row_matvecs.items(): + if mtvs['matvecs']: + self._make_matvec(data, mtvs['matvecs'], prefix=f'{submat}_MatMult') + + self._make_user_struct_callback() + self._make_whole_matvec() + self._make_whole_formfunc() + self._create_submatrices() + self._efuncs['PopulateMatContext'] = self.objs['dummyefunc'] + + def _make_whole_matvec(self): + objs = self.objs + body = self._whole_matvec_body() + + cb = PETScCallable( + self.sregistry.make_name(prefix='WholeMatMult'), + List(body=body), + retval=objs['err'], + parameters=(objs['J'], objs['X'], objs['Y']) + ) + self._main_matvec_callback = cb + self._efuncs[cb.name] = cb + + def _whole_matvec_body(self): + objs = self.objs + sobjs = self.solver_objs + + jctx = objs['ljacctx'] + ctx_main = petsc_call('MatShellGetContext', [objs['J'], Byref(jctx)]) + + nonzero_submats = self.submatrices.nonzero_submatrix_keys + + calls = () + for sm in nonzero_submats: + idx = self.submatrices.submat_to_index[sm] + ctx = sobjs[f'{sm}ctx'] + X = sobjs[f'{sm}X'] + Y = sobjs[f'{sm}Y'] + rows = objs['rows'].base + cols = objs['cols'].base + sm_indexed = objs['Submats'].indexed[idx] + + calls += ( + DummyExpr(sobjs[sm], FieldFromPointer(sm_indexed, jctx)), + petsc_call('MatShellGetContext', [sobjs[sm], Byref(ctx)]), + petsc_call( + 'VecGetSubVector', + [objs['X'], Deref(FieldFromPointer(cols, ctx)), Byref(X)] + ), + petsc_call( + 'VecGetSubVector', + [objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)] + ), + petsc_call('MatMult', [sobjs[sm], X, Y]), + petsc_call( + 'VecRestoreSubVector', + [objs['X'], Deref(FieldFromPointer(cols, ctx)), Byref(X)] + ), + petsc_call( + 'VecRestoreSubVector', + [objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)] + ), + ) + return CallableBody( + List(body=(ctx_main, BlankLine) + calls), + init=(objs['begin_user'],), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + def _make_whole_formfunc(self): + objs = self.objs + body = self._whole_formfunc_body() + + cb = PETScCallable( + self.sregistry.make_name(prefix='WholeFormFunc'), + List(body=body), + retval=objs['err'], + parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr']) + ) + self._main_formfunc_callback = cb + self._efuncs[cb.name] = cb + + def _whole_formfunc_body(self): + objs = self.objs + sobjs = self.solver_objs + + ljacctx = objs['ljacctx'] + struct_cast = DummyExpr(ljacctx, JacobianStructCast(objs['dummyptr'])) + X = objs['X'] + F = objs['F'] + + targets = self.injectsolve.expr.rhs.fielddata.targets + + deref_subdms = Dereference(objs['LocalSubdms'], ljacctx) + deref_fields = Dereference(objs['LocalFields'], ljacctx) + + calls = () + for i, t in enumerate(targets): + field_ptr = FieldFromPointer(objs['LocalFields'].indexed[i], ljacctx) + x_name = f'Xglobal{t.name}' + f_name = f'Fglobal{t.name}' + + calls += ( + petsc_call('VecGetSubVector', [X, field_ptr, Byref(sobjs[x_name])]), + petsc_call('VecGetSubVector', [F, field_ptr, Byref(sobjs[f_name])]), + petsc_call(self.formfuncs[i].name, [objs['snes'], sobjs[x_name], + sobjs[f_name], VOIDP(objs['LocalSubdms'].indexed[i])]), + petsc_call('VecRestoreSubVector', [X, field_ptr, Byref(sobjs[x_name])]), + petsc_call('VecRestoreSubVector', [F, field_ptr, Byref(sobjs[f_name])]), + ) + return CallableBody( + List(body=calls + (BlankLine,)), + init=(objs['begin_user'],), + stacks=(struct_cast, deref_subdms, deref_fields), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + def _create_submatrices(self): + body = self._submat_callback_body() + objs = self.objs + params = ( + objs['J'], + objs['nfields'], + objs['irow'], + objs['icol'], + objs['matreuse'], + objs['Submats'], + ) + cb = PETScCallable( + self.sregistry.make_name(prefix='MatCreateSubMatrices'), + body, + retval=objs['err'], + parameters=params + ) + self._submatrices_callback = cb + self._efuncs[cb.name] = cb + + def _submat_callback_body(self): + objs = self.objs + sobjs = self.solver_objs + + n_submats = DummyExpr( + objs['nsubmats'], Mul(objs['nfields'], objs['nfields']) + ) + + malloc_submats = petsc_call('PetscCalloc1', [objs['nsubmats'], objs['Submats']]) + + mat_get_dm = petsc_call('MatGetDM', [objs['J'], Byref(sobjs['callbackdm'])]) + + dm_get_app = petsc_call( + 'DMGetApplicationContext', [sobjs['callbackdm'], Byref(objs['dummyctx'])] + ) + + get_ctx = petsc_call('MatShellGetContext', [objs['J'], Byref(objs['ljacctx'])]) + + Null = objs['Null'] + dm_get_info = petsc_call( + 'DMDAGetInfo', [ + sobjs['callbackdm'], Null, Byref(sobjs['M']), Byref(sobjs['N']), + Null, Null, Null, Null, Byref(objs['dof']), Null, Null, Null, Null, Null + ] + ) + subblock_rows = DummyExpr(objs['subblockrows'], Mul(sobjs['M'], sobjs['N'])) + subblock_cols = DummyExpr(objs['subblockcols'], Mul(sobjs['M'], sobjs['N'])) + + ptr = DummyExpr(objs['submat_arr']._C_symbol, Deref(objs['Submats']), init=True) + + mat_create = petsc_call('MatCreate', [self.objs['comm'], Byref(objs['block'])]) + + mat_set_sizes = petsc_call( + 'MatSetSizes', [ + objs['block'], 'PETSC_DECIDE', 'PETSC_DECIDE', + objs['subblockrows'], objs['subblockcols'] + ] + ) + + mat_set_type = petsc_call('MatSetType', [objs['block'], 'MATSHELL']) + + malloc = petsc_call('PetscMalloc1', [1, Byref(objs['subctx'])]) + i = Dimension(name='i') + + row_idx = DummyExpr(objs['rowidx'], IntDiv(i, objs['dof'])) + col_idx = DummyExpr(objs['colidx'], Mod(i, objs['dof'])) + + deref_subdm = Dereference(objs['Subdms'], objs['ljacctx']) + + set_rows = DummyExpr( + FieldFromPointer(objs['rows'].base, objs['subctx']), + Byref(objs['irow'].indexed[objs['rowidx']]) + ) + set_cols = DummyExpr( + FieldFromPointer(objs['cols'].base, objs['subctx']), + Byref(objs['icol'].indexed[objs['colidx']]) + ) + dm_set_ctx = petsc_call( + 'DMSetApplicationContext', [ + objs['Subdms'].indexed[objs['rowidx']], objs['dummyctx'] + ] + ) + matset_dm = petsc_call('MatSetDM', [ + objs['block'], objs['Subdms'].indexed[objs['rowidx']] + ]) + + set_ctx = petsc_call('MatShellSetContext', [objs['block'], objs['subctx']]) + + mat_setup = petsc_call('MatSetUp', [objs['block']]) + + assign_block = DummyExpr(objs['submat_arr'].indexed[i], objs['block']) + + iter_body = ( + mat_create, + mat_set_sizes, + mat_set_type, + malloc, + row_idx, + col_idx, + set_rows, + set_cols, + dm_set_ctx, + matset_dm, + set_ctx, + mat_setup, + assign_block + ) + + upper_bound = objs['nsubmats'] - 1 + iteration = Iteration(List(body=iter_body), i, upper_bound) + + nonzero_submats = self.submatrices.nonzero_submatrix_keys + matvec_lookup = {mv.name.split('_')[0]: mv for mv in self.matvecs} + + matmult_op = [ + petsc_call( + 'MatShellSetOperation', + [ + objs['submat_arr'].indexed[self.submatrices.submat_to_index[sb]], + 'MATOP_MULT', + MatShellSetOp(matvec_lookup[sb].name, void, void), + ], + ) + for sb in nonzero_submats if sb in matvec_lookup + ] + + body = [ + n_submats, + malloc_submats, + mat_get_dm, + dm_get_app, + dm_get_info, + subblock_rows, + subblock_cols, + ptr, + BlankLine, + iteration, + ] + matmult_op + + return CallableBody( + List(body=tuple(body)), + init=(objs['begin_user'],), + stacks=(get_ctx, deref_subdm), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + +class BaseObjectBuilder: + """ + A base class for constructing objects needed for a PETSc solver. + Designed to be extended by subclasses, which can override the `_extend_build` + method to support specific use cases. + """ + def __init__(self, **kwargs): + self.injectsolve = kwargs.get('injectsolve') + self.objs = kwargs.get('objs') + self.sregistry = kwargs.get('sregistry') + self.fielddata = self.injectsolve.expr.rhs.fielddata + self.solver_objs = self._build() + + def _build(self): + """ + # TODO: update docs + Constructs the core dictionary of solver objects and allows + subclasses to extend or modify it via `_extend_build`. + Returns: + dict: A dictionary containing the following objects: + - 'Jac' (Mat): A matrix representing the jacobian. + - 'xglobal' (GlobalVec): The global solution vector. + - 'xlocal' (LocalVec): The local solution vector. + - 'bglobal': (GlobalVec) Global RHS vector `b`, where `F(x) = b`. + - 'blocal': (LocalVec) Local RHS vector `b`, where `F(x) = b`. + - 'ksp': (KSP) Krylov solver object that manages the linear solver. + - 'pc': (PC) Preconditioner object. + - 'snes': (SNES) Nonlinear solver object. + - 'localsize' (PetscInt): The local length of the solution vector. + - 'dmda' (DM): The DMDA object associated with this solve, linked to + the SNES object via `SNESSetDM`. + - 'callbackdm' (CallbackDM): The DM object accessed within callback + functions via `SNESGetDM`. + """ + sreg = self.sregistry + targets = self.fielddata.targets + base_dict = { + 'Jac': Mat(sreg.make_name(prefix='J')), + 'xglobal': GlobalVec(sreg.make_name(prefix='xglobal')), + 'xlocal': LocalVec(sreg.make_name(prefix='xlocal')), + 'bglobal': GlobalVec(sreg.make_name(prefix='bglobal')), + 'blocal': LocalVec(sreg.make_name(prefix='blocal')), + 'ksp': KSP(sreg.make_name(prefix='ksp')), + 'pc': PC(sreg.make_name(prefix='pc')), + 'snes': SNES(sreg.make_name(prefix='snes')), + 'localsize': PetscInt(sreg.make_name(prefix='localsize')), + 'dmda': DM(sreg.make_name(prefix='da'), dofs=len(targets)), + 'callbackdm': CallbackDM(sreg.make_name(prefix='dm')), + } + self._target_dependent(base_dict) + return self._extend_build(base_dict) + + def _target_dependent(self, base_dict): + """ + '_ptr' (StartPtr): A pointer to the beginning of the solution array + that will be updated at each time step. + """ + sreg = self.sregistry + target = self.fielddata.target + base_dict[f'{target.name}_ptr'] = StartPtr( + sreg.make_name(prefix=f'{target.name}_ptr'), target.dtype + ) + + def _extend_build(self, base_dict): + """ + Subclasses can override this method to extend or modify the + base dictionary of solver objects. + """ + return base_dict + + +class CoupledObjectBuilder(BaseObjectBuilder): + def _extend_build(self, base_dict): + injectsolve = self.injectsolve + sreg = self.sregistry + objs = self.objs + targets = self.fielddata.targets + + base_dict['fields'] = PointerIS( + name=sreg.make_name(prefix='fields'), nindices=len(targets) + ) + base_dict['subdms'] = PointerDM( + name=sreg.make_name(prefix='subdms'), nindices=len(targets) + ) + base_dict['nfields'] = PetscInt(sreg.make_name(prefix='nfields')) + + space_dims = len(self.fielddata.grid.dimensions) + + dim_labels = ["M", "N", "P"] + base_dict.update({ + dim_labels[i]: PetscInt(dim_labels[i]) for i in range(space_dims) + }) + + submatrices = injectsolve.expr.rhs.fielddata.submatrices + submatrix_keys = submatrices.submatrix_keys + + base_dict['jacctx'] = JacobianStruct( + name=sreg.make_name(prefix=objs['ljacctx'].name), + fields=objs['ljacctx'].fields, + ) + + for key in submatrix_keys: + base_dict[key] = Mat(name=key) + base_dict[f'{key}ctx'] = SubMatrixStruct( + name=f'{key}ctx', + fields=objs['subctx'].fields, + ) + base_dict[f'{key}X'] = LocalVec(f'{key}X') + base_dict[f'{key}Y'] = LocalVec(f'{key}Y') + base_dict[f'{key}F'] = LocalVec(f'{key}F') + + return base_dict + + def _target_dependent(self, base_dict): + sreg = self.sregistry + targets = self.fielddata.targets + for t in targets: + name = t.name + base_dict[f'{name}_ptr'] = StartPtr( + sreg.make_name(prefix=f'{name}_ptr'), t.dtype + ) + base_dict[f'xlocal{name}'] = LocalVec( + sreg.make_name(prefix=f'xlocal{name}'), liveness='eager' + ) + base_dict[f'Fglobal{name}'] = LocalVec( + sreg.make_name(prefix=f'Fglobal{name}'), liveness='eager' + ) + base_dict[f'Xglobal{name}'] = LocalVec( + sreg.make_name(prefix=f'Xglobal{name}') + ) + base_dict[f'xglobal{name}'] = GlobalVec( + sreg.make_name(prefix=f'xglobal{name}') + ) + base_dict[f'blocal{name}'] = LocalVec( + sreg.make_name(prefix=f'blocal{name}'), liveness='eager' + ) + base_dict[f'bglobal{name}'] = GlobalVec( + sreg.make_name(prefix=f'bglobal{name}') + ) + base_dict[f'da{name}'] = DM( + sreg.make_name(prefix=f'da{name}'), liveness='eager' + ) + base_dict[f'scatter{name}'] = VecScatter( + sreg.make_name(prefix=f'scatter{name}') + ) + + +class BaseSetup: + def __init__(self, **kwargs): + self.injectsolve = kwargs.get('injectsolve') + self.objs = kwargs.get('objs') + self.solver_objs = kwargs.get('solver_objs') + self.cbbuilder = kwargs.get('cbbuilder') + self.fielddata = self.injectsolve.expr.rhs.fielddata + self.calls = self._setup() + + @property + def snes_ctx(self): + """ + The [optional] context for private data for the function evaluation routine. + https://petsc.org/main/manualpages/SNES/SNESSetFunction/ + """ + return VOIDP(self.solver_objs['dmda']) + + def _setup(self): + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['dmda'] + + solver_params = self.injectsolve.expr.rhs.solver_parameters + + snes_create = petsc_call('SNESCreate', [objs['comm'], Byref(sobjs['snes'])]) + + snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda]) + + create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])]) + + # NOTE: Assuming all solves are linear for now + snes_set_type = petsc_call('SNESSetType', [sobjs['snes'], 'SNESKSPONLY']) + + snes_set_jac = petsc_call( + 'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'], + sobjs['Jac'], 'MatMFFDComputeJacobian', objs['Null']] + ) + + global_x = petsc_call('DMCreateGlobalVector', + [dmda, Byref(sobjs['xglobal'])]) + + global_b = petsc_call('DMCreateGlobalVector', + [dmda, Byref(sobjs['bglobal'])]) + + snes_get_ksp = petsc_call('SNESGetKSP', + [sobjs['snes'], Byref(sobjs['ksp'])]) + + ksp_set_tols = petsc_call( + 'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'], + solver_params['ksp_atol'], solver_params['ksp_divtol'], + solver_params['ksp_max_it']] + ) + + ksp_set_type = petsc_call( + 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]] + ) + + ksp_get_pc = petsc_call( + 'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])] + ) + + # Even though the default will be jacobi, set to PCNONE for now + pc_set_type = petsc_call('PCSetType', [sobjs['pc'], 'PCNONE']) + + ksp_set_from_ops = petsc_call('KSPSetFromOptions', [sobjs['ksp']]) + + matvec = self.cbbuilder.main_matvec_callback + matvec_operation = petsc_call( + 'MatShellSetOperation', + [sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)] + ) + formfunc = self.cbbuilder.main_formfunc_callback + formfunc_operation = petsc_call( + 'SNESSetFunction', + [sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void), + self.snes_ctx] + ) + + dmda_calls = self._create_dmda_calls(dmda) + + mainctx = sobjs['userctx'] + + call_struct_callback = petsc_call( + self.cbbuilder.user_struct_callback.name, [Byref(mainctx)] + ) + + # TODO: maybe don't need to explictly set this + mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda]) + + calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)]) + + base_setup = dmda_calls + ( + snes_create, + snes_set_dm, + create_matrix, + snes_set_jac, + snes_set_type, + global_x, + global_b, + snes_get_ksp, + ksp_set_tols, + ksp_set_type, + ksp_get_pc, + pc_set_type, + ksp_set_from_ops, + matvec_operation, + formfunc_operation, + call_struct_callback, + mat_set_dm, + calls_set_app_ctx, + BlankLine + ) + extended_setup = self._extend_setup() + return base_setup + extended_setup + + def _extend_setup(self): + """ + Hook for subclasses to add additional setup calls. + """ + return () + + def _create_dmda_calls(self, dmda): + dmda_create = self._create_dmda(dmda) + dm_setup = petsc_call('DMSetUp', [dmda]) + dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL']) + return dmda_create, dm_setup, dm_mat_type + + def _create_dmda(self, dmda): + objs = self.objs + grid = self.fielddata.grid + nspace_dims = len(grid.dimensions) + + # MPI communicator + args = [objs['comm']] + + # Type of ghost nodes + args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(nspace_dims)]) + + # Stencil type + if nspace_dims > 1: + args.append('DMDA_STENCIL_BOX') + + # Global dimensions + args.extend(list(grid.shape)[::-1]) + # No.of processors in each dimension + if nspace_dims > 1: + args.extend(list(grid.distributor.topology)[::-1]) + + # Number of degrees of freedom per node + args.append(dmda.dofs) + # "Stencil width" -> size of overlap + stencil_width = self.fielddata.space_order + args.append(stencil_width) + args.extend([objs['Null']]*nspace_dims) + + # The distributed array object + args.append(Byref(dmda)) + + # The PETSc call used to create the DMDA + dmda = petsc_call(f'DMDACreate{nspace_dims}d', args) + + return dmda + + +class CoupledSetup(BaseSetup): + @property + def snes_ctx(self): + return Byref(self.solver_objs['jacctx']) + + def _extend_setup(self): + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['dmda'] + create_field_decomp = petsc_call( + 'DMCreateFieldDecomposition', + [dmda, Byref(sobjs['nfields']), objs['Null'], Byref(sobjs['fields']), + Byref(sobjs['subdms'])] + ) + submat_cb = self.cbbuilder.submatrices_callback + matop_create_submats_op = petsc_call( + 'MatShellSetOperation', + [sobjs['Jac'], 'MATOP_CREATE_SUBMATRICES', + MatShellSetOp(submat_cb.name, void, void)] + ) + + call_coupled_struct_callback = petsc_call( + 'PopulateMatContext', + [Byref(sobjs['jacctx']), sobjs['subdms'], sobjs['fields']] + ) + + shell_set_ctx = petsc_call( + 'MatShellSetContext', [sobjs['Jac'], Byref(sobjs['jacctx']._C_symbol)] + ) + + create_submats = petsc_call( + 'MatCreateSubMatrices', + [sobjs['Jac'], sobjs['nfields'], sobjs['fields'], + sobjs['fields'], 'MAT_INITIAL_MATRIX', + Byref(FieldFromComposite(objs['Submats'].base, sobjs['jacctx']))] + ) + + targets = self.fielddata.targets + + deref_dms = [ + DummyExpr(sobjs[f'da{t.name}'], sobjs['subdms'].indexed[i]) + for i, t in enumerate(targets) + ] + + xglobals = [petsc_call( + 'DMCreateGlobalVector', + [sobjs[f'da{t.name}'], Byref(sobjs[f'xglobal{t.name}'])] + ) for t in targets] + + bglobals = [petsc_call( + 'DMCreateGlobalVector', + [sobjs[f'da{t.name}'], Byref(sobjs[f'bglobal{t.name}'])] + ) for t in targets] + + return ( + create_field_decomp, + matop_create_submats_op, + call_coupled_struct_callback, + shell_set_ctx, + create_submats + ) + tuple(deref_dms) + tuple(xglobals) + tuple(bglobals) + + +class Solver: + def __init__(self, **kwargs): + self.injectsolve = kwargs.get('injectsolve') + self.objs = kwargs.get('objs') + self.solver_objs = kwargs.get('solver_objs') + self.iters = kwargs.get('iters') + self.cbbuilder = kwargs.get('cbbuilder') + self.timedep = kwargs.get('timedep') + # TODO: Should/could _execute_solve be a cached_property? + self.calls = self._execute_solve() + + def _execute_solve(self): + """ + Assigns the required time iterators to the struct and executes + the necessary calls to execute the SNES solver. + """ + sobjs = self.solver_objs + target = self.injectsolve.expr.rhs.fielddata.target + + struct_assignment = self.timedep.assign_time_iters(sobjs['userctx']) + + rhs_callback = self.cbbuilder.formrhs[0] + + dmda = sobjs['dmda'] + + rhs_call = petsc_call(rhs_callback.name, [sobjs['dmda'], sobjs['bglobal']]) + + local_x = petsc_call('DMCreateLocalVector', + [dmda, Byref(sobjs['xlocal'])]) + + vec_replace_array = self.timedep.replace_array(target) + + dm_local_to_global_x = petsc_call( + 'DMLocalToGlobal', [dmda, sobjs['xlocal'], insert_vals, + sobjs['xglobal']] + ) + + snes_solve = petsc_call('SNESSolve', [ + sobjs['snes'], sobjs['bglobal'], sobjs['xglobal']] + ) + + dm_global_to_local_x = petsc_call('DMGlobalToLocal', [ + dmda, sobjs['xglobal'], insert_vals, sobjs['xlocal']] + ) + + run_solver_calls = (struct_assignment,) + ( + rhs_call, + local_x + ) + vec_replace_array + ( + dm_local_to_global_x, + snes_solve, + dm_global_to_local_x, + BlankLine, + ) + return List(body=run_solver_calls) + + @cached_property + def spatial_body(self): + spatial_body = [] + # TODO: remove the iters[0] + for tree in retrieve_iteration_tree(self.iters[0]): + root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0] + if self.injectsolve in FindNodes(PetscMetaData).visit(root): + spatial_body.append(root) + spatial_body, = spatial_body + return spatial_body + + +class CoupledSolver(Solver): + def _execute_solve(self): + """ + Assigns the required time iterators to the struct and executes + the necessary calls to execute the SNES solver. + """ + sobjs = self.solver_objs + + struct_assignment = self.timedep.assign_time_iters(sobjs['userctx']) + + rhs_callbacks = self.cbbuilder.formrhs + + xglob = sobjs['xglobal'] + bglob = sobjs['bglobal'] + + targets = self.injectsolve.expr.rhs.fielddata.targets + + # TODO: optimise the ccode generated here + pre_solve = () + post_solve = () + + for i, (c, t) in enumerate(zip(rhs_callbacks, targets)): + name = t.name + dm = sobjs[f'da{name}'] + target_xloc = sobjs[f'xlocal{name}'] + target_xglob = sobjs[f'xglobal{name}'] + target_bglob = sobjs[f'bglobal{name}'] + field = sobjs['fields'].indexed[i] + s = sobjs[f'scatter{name}'] + + pre_solve += ( + petsc_call(c.name, [dm, target_bglob]), + petsc_call('DMCreateLocalVector', [dm, Byref(target_xloc)]), + self.timedep.replace_array(t), + petsc_call( + 'DMLocalToGlobal', + [dm, target_xloc, insert_vals, target_xglob] + ), + petsc_call( + 'VecScatterCreate', + [xglob, field, target_xglob, self.objs['Null'], Byref(s)] + ), + petsc_call( + 'VecScatterBegin', + [s, target_xglob, xglob, insert_vals, sreverse] + ), + petsc_call( + 'VecScatterEnd', + [s, target_xglob, xglob, insert_vals, sreverse] + ), + petsc_call( + 'VecScatterBegin', + [s, target_bglob, bglob, insert_vals, sreverse] + ), + petsc_call( + 'VecScatterEnd', + [s, target_bglob, bglob, insert_vals, sreverse] + ), + ) + + post_solve += ( + petsc_call( + 'VecScatterBegin', + [s, xglob, target_xglob, insert_vals, sforward] + ), + petsc_call( + 'VecScatterEnd', + [s, xglob, target_xglob, insert_vals, sforward] + ), + petsc_call( + 'DMGlobalToLocal', + [dm, target_xglob, insert_vals, target_xloc] + ) + ) + + snes_solve = (petsc_call('SNESSolve', [sobjs['snes'], bglob, xglob]),) + + return List( + body=( + (struct_assignment,) + + pre_solve + + snes_solve + + post_solve + + (BlankLine,) + ) + ) + + +class NonTimeDependent: + def __init__(self, **kwargs): + self.injectsolve = kwargs.get('injectsolve') + self.iters = kwargs.get('iters') + self.sobjs = kwargs.get('solver_objs') + self.kwargs = kwargs + self.origin_to_moddim = self._origin_to_moddim_mapper(self.iters) + self.time_idx_to_symb = self.injectsolve.expr.rhs.time_mapper + + def _origin_to_moddim_mapper(self, iters): + return {} + + def uxreplace_time(self, body): + return body + + def replace_array(self, target): + """ + VecReplaceArray() is a PETSc function that allows replacing the array + of a `Vec` with a user provided array. + https://petsc.org/release/manualpages/Vec/VecReplaceArray/ + + This function is used to replace the array of the PETSc solution `Vec` + with the array from the `Function` object representing the target. + + Examples + -------- + >>> target + f1(x, y) + >>> call = replace_array(target) + >>> print(call) + PetscCall(VecReplaceArray(xlocal0,f1_vec->data)); + """ + sobjs = self.sobjs + + field_from_ptr = FieldFromPointer( + target.function._C_field_data, target.function._C_symbol + ) + xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) + return (petsc_call('VecReplaceArray', [xlocal, field_from_ptr]),) + + def assign_time_iters(self, struct): + return [] + + +class TimeDependent(NonTimeDependent): + """ + A class for managing time-dependent solvers. + + This includes scenarios where the target is not directly a `TimeFunction`, + but depends on other functions that are. + + Outline of time loop abstraction with PETSc: + + - At PETScSolve, time indices are replaced with temporary `Symbol` objects + via a mapper (e.g., {t: tau0, t + dt: tau1}) to prevent the time loop + from being generated in the callback functions. These callbacks, needed + for each `SNESSolve` at every time step, don't require the time loop, but + may still need access to data from other time steps. + - All `Function` objects are passed through the initial lowering via the + `LinearSolveExpr` object, ensuring the correct time loop is generated + in the main kernel. + - Another mapper is created based on the modulo dimensions + generated by the `LinearSolveExpr` object in the main kernel + (e.g., {time: time, t: t0, t + 1: t1}). + - These two mappers are used to generate a final mapper `symb_to_moddim` + (e.g. {tau0: t0, tau1: t1}) which is used at the IET level to + replace the temporary `Symbol` objects in the callback functions with + the correct modulo dimensions. + - Modulo dimensions are updated in the matrix context struct at each time + step and can be accessed in the callback functions where needed. + """ + @property + def time_spacing(self): + return self.injectsolve.expr.rhs.grid.stepping_dim.spacing + + @cached_property + def symb_to_moddim(self): + """ + Maps temporary `Symbol` objects created during `PETScSolve` to their + corresponding modulo dimensions (e.g. creates {tau0: t0, tau1: t1}). + """ + mapper = { + v: k.xreplace({self.time_spacing: 1, -self.time_spacing: -1}) + for k, v in self.time_idx_to_symb.items() + } + return {symb: self.origin_to_moddim[mapper[symb]] for symb in mapper} + + def is_target_time(self, target): + return any(i.is_Time for i in target.dimensions) + + def target_time(self, target): + target_time = [ + i for i, d in zip(target.indices, target.dimensions) + if d.is_Time + ] + assert len(target_time) == 1 + target_time = target_time.pop() + return target_time + + def uxreplace_time(self, body): + return Uxreplace(self.symb_to_moddim).visit(body) + + def _origin_to_moddim_mapper(self, iters): + """ + Creates a mapper of the origin of the time dimensions to their corresponding + modulo dimensions from a list of `Iteration` objects. + + Examples + -------- + >>> iters + (, + ) + >>> _origin_to_moddim_mapper(iters) + {time: time, t: t0, t + 1: t1} + """ + time_iter = [i for i in iters if any(d.is_Time for d in i.dimensions)] + mapper = {} + + if not time_iter: + return mapper + + for i in time_iter: + for d in i.dimensions: + if d.is_Modulo: + mapper[d.origin] = d + elif d.is_Time: + mapper[d] = d + return mapper + + def replace_array(self, target): + """ + In the case that the actual target is time-dependent e.g a `TimeFunction`, + a pointer to the first element in the array that will be updated during + the time step is passed to VecReplaceArray(). + + Examples + -------- + >>> target + f1(time + dt, x, y) + >>> calls = replace_array(target) + >>> print(List(body=calls)) + PetscCall(VecGetSize(xlocal0,&(localsize0))); + float * f1_ptr0 = (time + 1)*localsize0 + (float*)(f1_vec->data); + PetscCall(VecReplaceArray(xlocal0,f1_ptr0)); + + >>> target + f1(t + dt, x, y) + >>> calls = replace_array(target) + >>> print(List(body=calls)) + PetscCall(VecGetSize(xlocal0,&(localsize0))); + float * f1_ptr0 = t1*localsize0 + (float*)(f1_vec->data); + PetscCall(VecReplaceArray(xlocal0,f1_ptr0)); + """ + sobjs = self.sobjs + + if self.is_target_time(target): + mapper = {self.time_spacing: 1, -self.time_spacing: -1} + + target_time = self.target_time(target).xreplace(mapper) + target_time = self.origin_to_moddim.get(target_time, target_time) + + xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) + start_ptr = sobjs[f'{target.name}_ptr'] + + return ( + petsc_call('VecGetSize', [xlocal, Byref(sobjs['localsize'])]), + DummyExpr( + start_ptr, + cast_mapper[(target.dtype, '*')]( + FieldFromPointer(target._C_field_data, target._C_symbol) + ) + Mul(target_time, sobjs['localsize']), + init=True + ), + petsc_call('VecReplaceArray', [xlocal, start_ptr]) + ) + return super().replace_array(target) + + def assign_time_iters(self, struct): + """ + Assign required time iterators to the struct. + These iterators are updated at each timestep in the main kernel + for use in callback functions. + + Examples + -------- + >>> struct + ctx + >>> struct.fields + [h_x, x_M, x_m, f1(t, x), t0, t1] + >>> assigned = assign_time_iters(struct) + >>> print(assigned[0]) + ctx.t0 = t0; + >>> print(assigned[1]) + ctx.t1 = t1; + """ + to_assign = [ + f for f in struct.fields if (f.is_Dimension and (f.is_Time or f.is_Modulo)) + ] + time_iter_assignments = [ + DummyExpr(FieldFromComposite(field, struct), field) + for field in to_assign + ] + return time_iter_assignments + + +void = 'void' +insert_vals = 'INSERT_VALUES' +sreverse = 'SCATTER_REVERSE' +sforward = 'SCATTER_FORWARD' diff --git a/devito/petsc/iet/utils.py b/devito/petsc/iet/utils.py new file mode 100644 index 0000000000..99da0468ad --- /dev/null +++ b/devito/petsc/iet/utils.py @@ -0,0 +1,23 @@ +from devito.petsc.iet.nodes import PetscMetaData, PETScCall +from devito.ir.equations import OpPetsc + + +def petsc_call(specific_call, call_args): + return PETScCall('PetscCall', [PETScCall(specific_call, arguments=call_args)]) + + +def petsc_call_mpi(specific_call, call_args): + return PETScCall('PetscCallMPI', [PETScCall(specific_call, arguments=call_args)]) + + +def petsc_struct(name, fields, pname, liveness='lazy', modifier=None): + # TODO: Fix this circular import + from devito.petsc.types.object import PETScStruct + return PETScStruct(name=name, pname=pname, + fields=fields, liveness=liveness, + modifier=modifier) + + +# Mapping special Eq operations to their corresponding IET Expression subclass types. +# These operations correspond to subclasses of Eq utilised within PETScSolve. +petsc_iet_mapper = {OpPetsc: PetscMetaData} diff --git a/devito/petsc/initialize.py b/devito/petsc/initialize.py new file mode 100644 index 0000000000..9126414658 --- /dev/null +++ b/devito/petsc/initialize.py @@ -0,0 +1,42 @@ +import os +import sys +from ctypes import POINTER, cast, c_char +import atexit + +from devito import Operator +from devito.types import Symbol +from devito.types.equation import PetscEq +from devito.petsc.types import Initialize, Finalize + +global _petsc_initialized +_petsc_initialized = False + + +def PetscInitialize(): + global _petsc_initialized + if not _petsc_initialized: + dummy = Symbol(name='d') + # TODO: Potentially just use cgen + the compiler machinery in Devito + # to generate these "dummy_ops" instead of using the Operator class. + # This would prevent circular imports when initializing during import + # from the PETSc module. + op_init = Operator( + [PetscEq(dummy, Initialize(dummy))], + name='kernel_init', opt='noop' + ) + op_finalize = Operator( + [PetscEq(dummy, Finalize(dummy))], + name='kernel_finalize', opt='noop' + ) + + # `argv_bytes` must be a list so the memory address persists + # `os.fsencode` should be preferred over `string().encode('utf-8')` + # in case there is some system specific encoding in use + argv_bytes = list(map(os.fsencode, sys.argv)) + argv_pointer = (POINTER(c_char)*len(sys.argv))( + *map(lambda s: cast(s, POINTER(c_char)), argv_bytes) + ) + op_init.apply(argc=len(sys.argv), argv=argv_pointer) + + atexit.register(op_finalize.apply) + _petsc_initialized = True diff --git a/devito/petsc/solve.py b/devito/petsc/solve.py new file mode 100644 index 0000000000..971fc8678b --- /dev/null +++ b/devito/petsc/solve.py @@ -0,0 +1,327 @@ +from functools import singledispatch + +import sympy + +from devito.finite_differences.differentiable import Mul +from devito.finite_differences.derivative import Derivative +from devito.types import Eq, Symbol, SteppingDimension, TimeFunction +from devito.types.equation import PetscEq +from devito.operations.solve import eval_time_derivatives +from devito.symbolics import retrieve_functions +from devito.tools import as_tuple, filter_ordered +from devito.petsc.types import (LinearSolveExpr, PETScArray, DMDALocalInfo, + FieldData, MultipleFieldData, SubMatrices) + + +__all__ = ['PETScSolve', 'EssentialBC'] + + +def PETScSolve(target_eqns, target=None, solver_parameters=None, **kwargs): + if target is not None: + return InjectSolve(solver_parameters, {target: target_eqns}).build_eq() + else: + return InjectSolveNested(solver_parameters, target_eqns).build_eq() + + +class InjectSolve: + def __init__(self, solver_parameters=None, target_eqns=None): + self.solver_params = solver_parameters + self.time_mapper = None + self.target_eqns = target_eqns + + def build_eq(self): + target, funcs, fielddata = self.linear_solve_args() + # Placeholder equation for inserting calls to the solver + linear_solve = LinearSolveExpr( + funcs, + self.solver_params, + fielddata=fielddata, + time_mapper=self.time_mapper, + localinfo=localinfo + ) + return [PetscEq(target, linear_solve)] + + def linear_solve_args(self): + target, eqns = next(iter(self.target_eqns.items())) + eqns = as_tuple(eqns) + + funcs = get_funcs(eqns) + self.time_mapper = generate_time_mapper(funcs) + arrays = self.generate_arrays(target) + + return target, tuple(funcs), self.generate_field_data(eqns, target, arrays) + + def generate_field_data(self, eqns, target, arrays): + formfuncs, formrhs = zip( + *[self.build_function_eqns(eq, target, arrays) for eq in eqns] + ) + matvecs = [self.build_matvec_eqns(eq, target, arrays) for eq in eqns] + + return FieldData( + target=target, + matvecs=matvecs, + formfuncs=formfuncs, + formrhs=formrhs, + arrays=arrays + ) + + def build_function_eqns(self, eq, target, arrays): + b, F_target, targets = separate_eqn(eq, target) + formfunc = self.make_formfunc(eq, F_target, arrays, targets) + formrhs = self.make_rhs(eq, b, arrays) + + return tuple(expr.subs(self.time_mapper) for expr in (formfunc, formrhs)) + + def build_matvec_eqns(self, eq, target, arrays): + b, F_target, targets = separate_eqn(eq, target) + if not F_target: + return None + matvec = self.make_matvec(eq, F_target, arrays, targets) + return matvec.subs(self.time_mapper) + + def make_matvec(self, eq, F_target, arrays, targets): + rhs = arrays['x'] if isinstance(eq, EssentialBC) else F_target.subs( + targets_to_arrays(arrays['x'], targets) + ) + return Eq(arrays['y'], rhs, subdomain=eq.subdomain) + + def make_formfunc(self, eq, F_target, arrays, targets): + rhs = 0. if isinstance(eq, EssentialBC) else F_target.subs( + targets_to_arrays(arrays['x'], targets) + ) + return Eq(arrays['f'], rhs, subdomain=eq.subdomain) + + def make_rhs(self, eq, b, arrays): + rhs = 0. if isinstance(eq, EssentialBC) else b + return Eq(arrays['b'], rhs, subdomain=eq.subdomain) + + def generate_arrays(self, target): + return { + p: PETScArray(name=f'{p}_{target.name}', + target=target, + liveness='eager', + localinfo=localinfo) + for p in prefixes + } + + +class InjectSolveNested(InjectSolve): + def linear_solve_args(self): + combined_eqns = [] + for eqns in self.target_eqns.values(): + combined_eqns.extend(eqns) + funcs = get_funcs(combined_eqns) + self.time_mapper = generate_time_mapper(funcs) + + targets = list(self.target_eqns.keys()) + jacobian = SubMatrices(targets) + + all_data = MultipleFieldData(jacobian) + + for target, eqns in self.target_eqns.items(): + eqns = as_tuple(eqns) + arrays = self.generate_arrays(target) + + self.update_jacobian(eqns, target, jacobian, arrays) + + fielddata = self.generate_field_data( + eqns, target, arrays + ) + all_data.add_field_data(fielddata) + + return target, tuple(funcs), all_data + + def update_jacobian(self, eqns, target, jacobian, arrays): + for submat, mtvs in jacobian.submatrices[target].items(): + matvecs = [ + self.build_matvec_eqns(eq, mtvs['derivative_wrt'], arrays) + for eq in eqns + ] + # Set submatrix only if there's at least one non-zero matvec + if any(m is not None for m in matvecs): + jacobian.set_submatrix(target, submat, matvecs) + + def generate_field_data(self, eqns, target, arrays): + formfuncs, formrhs = zip( + *[self.build_function_eqns(eq, target, arrays) for eq in eqns] + ) + + return FieldData( + target=target, + formfuncs=formfuncs, + formrhs=formrhs, + arrays=arrays + ) + + +class EssentialBC(Eq): + pass + + +def separate_eqn(eqn, target): + """ + Separate the equation into two separate expressions, + where F(target) = b. + """ + zeroed_eqn = Eq(eqn.lhs - eqn.rhs, 0) + zeroed_eqn = eval_time_derivatives(zeroed_eqn.lhs) + target_funcs = set(generate_targets(zeroed_eqn, target)) + b, F_target = remove_targets(zeroed_eqn, target_funcs) + return -b, F_target, target_funcs + + +def generate_targets(eq, target): + """ + Extract all the functions that share the same time index as the target + but may have different spatial indices. + """ + funcs = retrieve_functions(eq) + if isinstance(target, TimeFunction): + time_idx = target.indices[target.time_dim] + targets = [ + f for f in funcs if f.function is target.function and time_idx + in f.indices + ] + else: + targets = [f for f in funcs if f.function is target.function] + return targets + + +def targets_to_arrays(array, targets): + """ + Map each target in `targets` to a corresponding array generated from `array`, + matching the spatial indices of the target. + Example: + -------- + >>> array + vec_u(x, y) + >>> targets + {u(t + dt, x + h_x, y), u(t + dt, x - h_x, y), u(t + dt, x, y)} + >>> targets_to_arrays(array, targets) + {u(t + dt, x - h_x, y): vec_u(x - h_x, y), + u(t + dt, x + h_x, y): vec_u(x + h_x, y), + u(t + dt, x, y): vec_u(x, y)} + """ + space_indices = [ + tuple(f.indices[d] for d in f.space_dimensions) for f in targets + ] + array_targets = [ + array.subs(dict(zip(array.indices, i))) for i in space_indices + ] + return dict(zip(targets, array_targets)) + + +@singledispatch +def remove_targets(expr, targets): + return (0, expr) if expr in targets else (expr, 0) + + +@remove_targets.register(sympy.Add) +def _(expr, targets): + if not any(expr.has(t) for t in targets): + return (expr, 0) + + args_b, args_F = zip(*(remove_targets(a, targets) for a in expr.args)) + return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False)) + + +@remove_targets.register(Mul) +def _(expr, targets): + if not any(expr.has(t) for t in targets): + return (expr, 0) + + args_b, args_F = zip(*[remove_targets(a, targets) if any(a.has(t) for t in targets) + else (a, a) for a in expr.args]) + return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False)) + + +@remove_targets.register(Derivative) +def _(expr, targets): + return (0, expr) if any(expr.has(t) for t in targets) else (expr, 0) + + +@singledispatch +def centre_stencil(expr, target): + """ + Extract the centre stencil from an expression. Its coefficient is what + would appear on the diagonal of the matrix system if the matrix were + formed explicitly. + """ + return expr if expr == target else 0 + + +@centre_stencil.register(sympy.Add) +def _(expr, target): + if not expr.has(target): + return 0 + + args = [centre_stencil(a, target) for a in expr.args] + return expr.func(*args, evaluate=False) + + +@centre_stencil.register(Mul) +def _(expr, target): + if not expr.has(target): + return 0 + + args = [] + for a in expr.args: + if not a.has(target): + args.append(a) + else: + args.append(centre_stencil(a, target)) + + return expr.func(*args, evaluate=False) + + +@centre_stencil.register(Derivative) +def _(expr, target): + if not expr.has(target): + return 0 + args = [centre_stencil(a, target) for a in expr.evaluate.args] + return expr.evaluate.func(*args) + + +def generate_time_mapper(funcs): + """ + Replace time indices with `Symbols` in equations used within + PETSc callback functions. These symbols are Uxreplaced at the IET + level to align with the `TimeDimension` and `ModuloDimension` objects + present in the initial lowering. + NOTE: All functions used in PETSc callback functions are attached to + the `LinearSolveExpr` object, which is passed through the initial lowering + (and subsequently dropped and replaced with calls to run the solver). + Therefore, the appropriate time loop will always be correctly generated inside + the main kernel. + Examples + -------- + >>> funcs = [ + >>> f1(t + dt, x, y), + >>> g1(t + dt, x, y), + >>> g2(t, x, y), + >>> f1(t, x, y) + >>> ] + >>> generate_time_mapper(funcs) + {t + dt: tau0, t: tau1} + """ + time_indices = list({ + i if isinstance(d, SteppingDimension) else d + for f in funcs + for i, d in zip(f.indices, f.dimensions) + if d.is_Time + }) + tau_symbs = [Symbol('tau%d' % i) for i in range(len(time_indices))] + return dict(zip(time_indices, tau_symbs)) + + +def get_funcs(eqns): + funcs = [ + func + for eq in eqns + for func in retrieve_functions(eval_time_derivatives(eq.lhs - eq.rhs)) + ] + return filter_ordered(funcs) + + +localinfo = DMDALocalInfo(name='info', liveness='eager') +prefixes = ['y', 'x', 'f', 'b'] diff --git a/devito/petsc/types/__init__.py b/devito/petsc/types/__init__.py new file mode 100644 index 0000000000..ebcceb8d45 --- /dev/null +++ b/devito/petsc/types/__init__.py @@ -0,0 +1,3 @@ +from .array import * # noqa +from .types import * # noqa +from .object import * # noqa diff --git a/devito/petsc/types/array.py b/devito/petsc/types/array.py new file mode 100644 index 0000000000..381b1af121 --- /dev/null +++ b/devito/petsc/types/array.py @@ -0,0 +1,119 @@ +from functools import cached_property + +from devito.types.utils import DimensionTuple +from devito.types.array import ArrayBasic +from devito.finite_differences import Differentiable +from devito.types.basic import AbstractFunction +from devito.finite_differences.tools import fd_weights_registry +from devito.tools import as_tuple, CustomDtype +from devito.symbolics import FieldFromComposite + + +class PETScArray(ArrayBasic, Differentiable): + """ + PETScArrays are generated by the compiler only and represent + a customised variant of ArrayBasic. + Differentiable enables compatibility with standard Function objects, + allowing for the use of the `subs` method. + + PETScArray objects represent vector objects within PETSc. + They correspond to the spatial domain of a Function-like object + provided by the user, which is passed to PETScSolve as the target. + + TODO: Potentially re-evaluate and separate into PETScFunction(Differentiable) + and then PETScArray(ArrayBasic). + """ + + _data_alignment = False + + # Default method for the finite difference approximation weights computation. + _default_fd = 'taylor' + + __rkwargs__ = (AbstractFunction.__rkwargs__ + + ('target', 'liveness', 'coefficients', 'localinfo')) + + def __init_finalize__(self, *args, **kwargs): + + self._target = kwargs.get('target') + self._ndim = kwargs['ndim'] = len(self._target.space_dimensions) + self._dimensions = kwargs['dimensions'] = self._target.space_dimensions + + super().__init_finalize__(*args, **kwargs) + + # Symbolic (finite difference) coefficients + self._coefficients = kwargs.get('coefficients', self._default_fd) + if self._coefficients not in fd_weights_registry: + raise ValueError("coefficients must be one of %s" + " not %s" % (str(fd_weights_registry), self._coefficients)) + + self._localinfo = kwargs.get('localinfo', None) + + @property + def ndim(self): + return self._ndim + + @classmethod + def __dtype_setup__(cls, **kwargs): + return kwargs['target'].dtype + + @classmethod + def __indices_setup__(cls, *args, **kwargs): + dimensions = kwargs['target'].space_dimensions + if args: + indices = args + else: + indices = dimensions + return as_tuple(dimensions), as_tuple(indices) + + def __halo_setup__(self, **kwargs): + target = kwargs['target'] + halo = [target.halo[d] for d in target.space_dimensions] + return DimensionTuple(*halo, getters=target.space_dimensions) + + @property + def dimensions(self): + return self._dimensions + + @property + def target(self): + return self._target + + @property + def coefficients(self): + """Form of the coefficients of the function.""" + return self._coefficients + + @property + def shape(self): + return self.target.grid.shape + + @property + def space_order(self): + return self.target.space_order + + @property + def localinfo(self): + return self._localinfo + + @cached_property + def _shape_with_inhalo(self): + return self.target.shape_with_inhalo + + @cached_property + def shape_allocated(self): + return self.target.shape_allocated + + @cached_property + def _C_ctype(self): + return CustomDtype('PetscScalar', modifier=' *') + + @property + def symbolic_shape(self): + field_from_composites = [ + FieldFromComposite('g%sm' % d.name, self.localinfo) for d in self.dimensions] + # Reverse it since DMDA is setup backwards to Devito dimensions. + return DimensionTuple(*field_from_composites[::-1], getters=self.dimensions) + + @property + def _restrict_keyword(self): + return '' diff --git a/devito/petsc/types/macros.py b/devito/petsc/types/macros.py new file mode 100644 index 0000000000..4355535e64 --- /dev/null +++ b/devito/petsc/types/macros.py @@ -0,0 +1,5 @@ +import cgen as c + + +# TODO: Don't use c.Line here? +petsc_func_begin_user = c.Line('PetscFunctionBeginUser;') diff --git a/devito/petsc/types/object.py b/devito/petsc/types/object.py new file mode 100644 index 0000000000..9acf7def46 --- /dev/null +++ b/devito/petsc/types/object.py @@ -0,0 +1,326 @@ +from ctypes import POINTER, c_char +from devito.tools import CustomDtype, dtype_to_cstr, as_tuple, CustomIntType +from devito.types import (LocalObject, LocalCompositeObject, ModuloDimension, + TimeDimension, ArrayObject, CustomDimension) +from devito.symbolics import Byref, Cast +from devito.types.basic import DataSymbol +from devito.petsc.iet.utils import petsc_call + + +class CallbackDM(LocalObject): + """ + PETSc Data Management object (DM). This is the DM instance + accessed within the callback functions via `SNESGetDM` and + is not destroyed during callback execution. + """ + dtype = CustomDtype('DM') + + +class DM(LocalObject): + """ + PETSc Data Management object (DM). This is the primary DM instance + created within the main kernel and linked to the SNES + solver using `SNESSetDM`. + """ + dtype = CustomDtype('DM') + + def __init__(self, *args, dofs=1, **kwargs): + super().__init__(*args, **kwargs) + self._dofs = dofs + + @property + def dofs(self): + return self._dofs + + @property + def _C_free(self): + return petsc_call('DMDestroy', [Byref(self.function)]) + + # TODO: This is growing out of hand so switch to an enumeration or something? + @property + def _C_free_priority(self): + return 4 + + +class DMCast(Cast): + _base_typ = 'DM' + + +class CallbackMat(LocalObject): + """ + PETSc Matrix object (Mat) used within callback functions. + These instances are not destroyed during callback execution; + instead, they are managed and destroyed in the main kernel. + """ + dtype = CustomDtype('Mat') + + +class Mat(LocalObject): + dtype = CustomDtype('Mat') + + @property + def _C_free(self): + return petsc_call('MatDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 2 + + +class LocalVec(LocalObject): + """ + PETSc local vector object (Vec). + A local vector has ghost locations that contain values that are + owned by other MPI ranks. + """ + dtype = CustomDtype('Vec') + + +class CallbackGlobalVec(LocalVec): + """ + PETSc global vector object (Vec). For example, used for coupled + solves inside the `WholeFormFunc` callback. + """ + + +class GlobalVec(LocalVec): + """ + PETSc global vector object (Vec). + A global vector is a parallel vector that has no duplicate values + between MPI ranks. A global vector has no ghost locations. + """ + @property + def _C_free(self): + return petsc_call('VecDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 1 + + +class PetscMPIInt(LocalObject): + """ + PETSc datatype used to represent `int` parameters + to MPI functions. + """ + dtype = CustomDtype('PetscMPIInt') + + +class PetscInt(LocalObject): + """ + PETSc datatype used to represent `int` parameters + to PETSc functions. + """ + dtype = CustomIntType('PetscInt') + + +class KSP(LocalObject): + """ + PETSc KSP : Linear Systems Solvers. + Manages Krylov Methods. + """ + dtype = CustomDtype('KSP') + + +class CallbackSNES(LocalObject): + """ + PETSc SNES : Non-Linear Systems Solvers. + """ + dtype = CustomDtype('SNES') + + +class SNES(CallbackSNES): + @property + def _C_free(self): + return petsc_call('SNESDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 3 + + +class PC(LocalObject): + """ + PETSc object that manages all preconditioners (PC). + """ + dtype = CustomDtype('PC') + + +class KSPConvergedReason(LocalObject): + """ + PETSc object - reason a Krylov method was determined + to have converged or diverged. + """ + dtype = CustomDtype('KSPConvergedReason') + + +class DMDALocalInfo(LocalObject): + """ + PETSc object - C struct containing information + about the local grid. + """ + dtype = CustomDtype('DMDALocalInfo') + + +class PetscErrorCode(LocalObject): + """ + PETSc datatype used to return PETSc error codes. + https://petsc.org/release/manualpages/Sys/PetscErrorCode/ + """ + dtype = CustomDtype('PetscErrorCode') + + +class DummyArg(LocalObject): + """ + A void pointer used to satisfy the function + signature of the `FormFunction` callback. + """ + dtype = CustomDtype('void', modifier='*') + + +class MatReuse(LocalObject): + dtype = CustomDtype('MatReuse') + + +class VecScatter(LocalObject): + dtype = CustomDtype('VecScatter') + + +class StartPtr(LocalObject): + def __init__(self, name, dtype): + super().__init__(name=name) + self.dtype = CustomDtype(dtype_to_cstr(dtype), modifier=' *') + + +class SingleIS(LocalObject): + dtype = CustomDtype('IS') + + +class PETScStruct(LocalCompositeObject): + + @property + def time_dim_fields(self): + """ + Fields within the struct that are updated during the time loop. + These are not set in the `PopulateMatContext` callback. + """ + return [f for f in self.fields + if isinstance(f, (ModuloDimension, TimeDimension))] + + @property + def callback_fields(self): + """ + Fields within the struct that are initialized in the `PopulateMatContext` + callback. These fields are not updated in the time loop. + """ + return [f for f in self.fields if f not in self.time_dim_fields] + + _C_modifier = ' *' + + +class JacobianStruct(PETScStruct): + def __init__(self, name='jctx', pname='JacobianCtx', fields=None, + modifier='', liveness='lazy'): + super().__init__(name, pname, fields, modifier, liveness) + _C_modifier = None + + +class SubMatrixStruct(PETScStruct): + def __init__(self, name='subctx', pname='SubMatrixCtx', fields=None, + modifier=' *', liveness='lazy'): + super().__init__(name, pname, fields, modifier, liveness) + _C_modifier = None + + +class JacobianStructCast(Cast): + _base_typ = 'struct JacobianCtx *' + + +class PETScArrayObject(ArrayObject): + _data_alignment = False + + def __init_finalize__(self, *args, **kwargs): + self._nindices = kwargs.pop('nindices', 1) + super().__init_finalize__(*args, **kwargs) + + @classmethod + def __indices_setup__(cls, **kwargs): + try: + return as_tuple(kwargs['dimensions']), as_tuple(kwargs['dimensions']) + except KeyError: + nindices = kwargs.get('nindices', 1) + dim = CustomDimension(name='d', symbolic_size=nindices) + return (dim,), (dim,) + + @property + def dim(self): + assert len(self.dimensions) == 1 + return self.dimensions[0] + + @property + def nindices(self): + return self._nindices + + @property + def _C_name(self): + return self.name + + @property + def _mem_stack(self): + return False + + @property + def _C_free_priority(self): + return 0 + + +class CallbackPointerIS(PETScArrayObject): + """ + Index set object used for efficient indexing into vectors and matrices. + https://petsc.org/release/manualpages/IS/IS/ + """ + @property + def dtype(self): + return CustomDtype('IS', modifier=' *') + + +class PointerIS(CallbackPointerIS): + @property + def _C_free(self): + destroy_calls = [ + petsc_call('ISDestroy', [Byref(self.indexify().subs({self.dim: i}))]) + for i in range(self._nindices) + ] + destroy_calls.append(petsc_call('PetscFree', [self.function])) + return destroy_calls + + +class CallbackPointerDM(PETScArrayObject): + @property + def dtype(self): + return CustomDtype('DM', modifier=' *') + + +class PointerDM(CallbackPointerDM): + @property + def _C_free(self): + destroy_calls = [ + petsc_call('DMDestroy', [Byref(self.indexify().subs({self.dim: i}))]) + for i in range(self._nindices) + ] + destroy_calls.append(petsc_call('PetscFree', [self.function])) + return destroy_calls + + +class PointerMat(PETScArrayObject): + _C_modifier = ' *' + + @property + def dtype(self): + return CustomDtype('Mat', modifier=' *') + + +class ArgvSymbol(DataSymbol): + @property + def _C_ctype(self): + return POINTER(POINTER(c_char)) diff --git a/devito/petsc/types/types.py b/devito/petsc/types/types.py new file mode 100644 index 0000000000..4964b72e20 --- /dev/null +++ b/devito/petsc/types/types.py @@ -0,0 +1,327 @@ +import sympy + +from devito.tools import Reconstructable, sympy_mutex +from devito.tools.dtypes_lowering import mapper +from devito.petsc.utils import get_petsc_precision + + +class MetaData(sympy.Function, Reconstructable): + def __new__(cls, expr, **kwargs): + with sympy_mutex: + obj = sympy.Function.__new__(cls, expr) + obj._expr = expr + return obj + + @property + def expr(self): + return self._expr + + +class Initialize(MetaData): + pass + + +class Finalize(MetaData): + pass + + +class LinearSolveExpr(MetaData): + """ + A symbolic expression passed through the Operator, containing the metadata + needed to execute a linear solver. Linear problems are handled with + `SNESSetType(snes, KSPONLY)`, enabling a unified interface for both + linear and nonlinear solvers. + # TODO: extend this + defaults: + - 'ksp_type': String with the name of the PETSc Krylov method. + Default is 'gmres' (Generalized Minimal Residual Method). + https://petsc.org/main/manualpages/KSP/KSPType/ + - 'pc_type': String with the name of the PETSc preconditioner. + Default is 'jacobi' (i.e diagonal scaling preconditioning). + https://petsc.org/main/manualpages/PC/PCType/ + KSP tolerances: + https://petsc.org/release/manualpages/KSP/KSPSetTolerances/ + - 'ksp_rtol': Relative convergence tolerance. Default + is 1e-5. + - 'ksp_atol': Absolute convergence for tolerance. Default + is 1e-50. + - 'ksp_divtol': Divergence tolerance, amount residual norm can + increase before `KSPConvergedDefault()` concludes + that the method is diverging. Default is 1e5. + - 'ksp_max_it': Maximum number of iterations to use. Default + is 1e4. + """ + + __rargs__ = ('expr',) + __rkwargs__ = ('solver_parameters', 'fielddata', 'time_mapper', + 'localinfo') + + defaults = { + 'ksp_type': 'gmres', + 'pc_type': 'jacobi', + 'ksp_rtol': 1e-5, # Relative tolerance + 'ksp_atol': 1e-50, # Absolute tolerance + 'ksp_divtol': 1e5, # Divergence tolerance + 'ksp_max_it': 1e4 # Maximum iterations + } + + def __new__(cls, expr, solver_parameters=None, + fielddata=None, time_mapper=None, localinfo=None, **kwargs): + + if solver_parameters is None: + solver_parameters = cls.defaults + else: + for key, val in cls.defaults.items(): + solver_parameters[key] = solver_parameters.get(key, val) + + with sympy_mutex: + obj = sympy.Function.__new__(cls, expr) + + obj._expr = expr + obj._solver_parameters = solver_parameters + obj._fielddata = fielddata if fielddata else FieldData() + obj._time_mapper = time_mapper + obj._localinfo = localinfo + return obj + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.expr) + + __str__ = __repr__ + + def _sympystr(self, printer): + return str(self) + + def __hash__(self): + return hash(self.expr) + + def __eq__(self, other): + return (isinstance(other, LinearSolveExpr) and + self.expr == other.expr) + + @property + def expr(self): + return self._expr + + @property + def fielddata(self): + return self._fielddata + + @property + def solver_parameters(self): + return self._solver_parameters + + @property + def time_mapper(self): + return self._time_mapper + + @property + def localinfo(self): + return self._localinfo + + @property + def grid(self): + return self.fielddata.grid + + @classmethod + def eval(cls, *args): + return None + + func = Reconstructable._rebuild + + +class FieldData: + def __init__(self, target=None, matvecs=None, formfuncs=None, formrhs=None, + arrays=None, **kwargs): + self._target = kwargs.get('target', target) + + petsc_precision = mapper[get_petsc_precision()] + if self._target.dtype != petsc_precision: + raise TypeError( + f"Your target dtype must match the precision of your " + f"PETSc configuration. " + f"Expected {petsc_precision}, but got {self._target.dtype}." + ) + self._matvecs = matvecs + self._formfuncs = formfuncs + self._formrhs = formrhs + self._arrays = arrays + + @property + def target(self): + return self._target + + @property + def matvecs(self): + return self._matvecs + + @property + def formfuncs(self): + return self._formfuncs + + @property + def formrhs(self): + return self._formrhs + + @property + def arrays(self): + return self._arrays + + @property + def space_dimensions(self): + return self.target.space_dimensions + + @property + def grid(self): + return self.target.grid + + @property + def space_order(self): + return self.target.space_order + + @property + def targets(self): + return (self.target,) + + +class MultipleFieldData(FieldData): + def __init__(self, submatrices=None): + self.field_data_list = [] + self._submatrices = submatrices + + def add_field_data(self, field_data): + self.field_data_list.append(field_data) + + def get_field_data(self, target): + for field_data in self.field_data_list: + if field_data.target == target: + return field_data + raise ValueError(f"FieldData with target {target} not found.") + pass + + @property + def target(self): + return None + + @property + def targets(self): + return tuple(field_data.target for field_data in self.field_data_list) + + @property + def space_dimensions(self): + space_dims = {field_data.space_dimensions for field_data in self.field_data_list} + if len(space_dims) > 1: + # TODO: This may not actually have to be the case, but enforcing it for now + raise ValueError( + "All targets within a PETScSolve have to have the same space dimensions." + ) + return space_dims.pop() + + @property + def grid(self): + grids = [t.grid for t in self.targets] + if len(set(grids)) > 1: + raise ValueError( + "All targets within a PETScSolve have to have the same grid." + ) + return grids.pop() + + @property + def space_order(self): + # NOTE: since we use DMDA to create vecs for the coupled solves, + # all fields must have the same space order + # ... re think this? limitation. For now, just force the + # space order to be the same. + # This isn't a problem for segregated solves. + space_orders = [t.space_order for t in self.targets] + if len(set(space_orders)) > 1: + raise ValueError( + "All targets within a PETScSolve have to have the same space order." + ) + return space_orders.pop() + + @property + def submatrices(self): + return self._submatrices + + +class SubMatrices: + def __init__(self, targets): + self.targets = targets + self.submatrices = self._initialize_submatrices() + + def _initialize_submatrices(self): + """ + Create a dict of submatrices for each target with metadata. + """ + submatrices = {} + num_targets = len(self.targets) + + for i, target in enumerate(self.targets): + submatrices[target] = {} + for j in range(num_targets): + key = f'J{i}{j}' + submatrices[target][key] = { + 'matvecs': None, + 'derivative_wrt': self.targets[j], + 'index': i * num_targets + j + } + + return submatrices + + @property + def submatrix_keys(self): + """ + Return a list of all submatrix keys (e.g., ['J00', 'J01', 'J10', 'J11']). + """ + return [key for submats in self.submatrices.values() for key in submats.keys()] + + @property + def nonzero_submatrix_keys(self): + """ + Returns a list of submats where 'matvecs' is not None. + """ + return [ + key + for submats in self.submatrices.values() + for key, value in submats.items() + if value['matvecs'] is not None + ] + + @property + def submat_to_index(self): + """ + Returns a dict mapping submatrix keys to their index. + """ + return { + key: value['index'] + for submats in self.submatrices.values() + for key, value in submats.items() + } + + def set_submatrix(self, field, key, matvecs): + """ + Set a specific submatrix for a field. + + Parameters + ---------- + field : Function + The target field that the submatrix operates on. + key: str + The identifier for the submatrix (e.g., 'J00', 'J01'). + matvecs: list of Eq + The matrix-vector equations forming the submatrix. + """ + if field in self.submatrices and key in self.submatrices[field]: + self.submatrices[field][key]["matvecs"] = matvecs + else: + raise KeyError(f'Invalid field ({field}) or submatrix key ({key})') + + def get_submatrix(self, field, key): + """ + Retrieve a specific submatrix. + """ + return self.submatrices.get(field, {}).get(key, None) + + def __repr__(self): + return str(self.submatrices) diff --git a/devito/petsc/utils.py b/devito/petsc/utils.py new file mode 100644 index 0000000000..782b8501d3 --- /dev/null +++ b/devito/petsc/utils.py @@ -0,0 +1,82 @@ +import os +from pathlib import Path + +from devito.tools import memoized_func + + +solver_mapper = { + 'gmres': 'KSPGMRES', + 'jacobi': 'PCJACOBI', + None: 'PCNONE' +} + + +@memoized_func +def get_petsc_dir(): + # *** First try: via commonly used environment variables + for i in ['PETSC_DIR']: + petsc_dir = os.environ.get(i) + if petsc_dir: + return petsc_dir + # TODO: Raise error if PETSC_DIR is not set + return None + + +@memoized_func +def get_petsc_arch(): + # *** First try: via commonly used environment variables + for i in ['PETSC_ARCH']: + petsc_arch = os.environ.get(i) + if petsc_arch: + return petsc_arch + # TODO: Raise error if PETSC_ARCH is not set + return None + + +def core_metadata(): + petsc_dir = get_petsc_dir() + petsc_arch = get_petsc_arch() + + # Include directories + global_include = os.path.join(petsc_dir, 'include') + config_specific_include = os.path.join(petsc_dir, f'{petsc_arch}', 'include') + include_dirs = (global_include, config_specific_include) + + # Lib directories + lib_dir = os.path.join(petsc_dir, f'{petsc_arch}', 'lib') + + return { + 'includes': ('petscsnes.h', 'petscdmda.h'), + 'include_dirs': include_dirs, + 'libs': ('petsc'), + 'lib_dirs': lib_dir, + 'ldflags': ('-Wl,-rpath,%s' % lib_dir) + } + + +@memoized_func +def get_petsc_variables(): + """ + Taken from https://www.firedrakeproject.org/_modules/firedrake/petsc.html + + Get a dict of PETSc environment variables from the file: + $PETSC_DIR/$PETSC_ARCH/lib/petsc/conf/petscvariables + """ + petsc_dir = get_petsc_dir() + petsc_arch = get_petsc_arch() + path = [petsc_dir, petsc_arch, 'lib', 'petsc', 'conf', 'petscvariables'] + variables_path = Path(*path) + + with open(variables_path) as fh: + # Split lines on first '=' (assignment) + splitlines = (line.split("=", maxsplit=1) for line in fh.readlines()) + return {k.strip(): v.strip() for k, v in splitlines} + + +@memoized_func +def get_petsc_precision(): + """ + Get the PETSc precision. + """ + petsc_variables = get_petsc_variables() + return petsc_variables['PETSC_PRECISION'] diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 5b13262ded..72413bcaaa 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -10,7 +10,7 @@ from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, - int4) + int4, CustomIntType) from devito.types import Symbol from devito.types.basic import Basic @@ -19,8 +19,8 @@ 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', - 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc', - 'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper'] + 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'VOIDP', 'Null', 'SizeOf', 'rfunc', + 'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper', 'Mod'] class CondEq(sympy.Eq): @@ -90,9 +90,16 @@ def __new__(cls, lhs, rhs, params=None): # Perhaps it's a symbolic RHS -- but we wanna be sure it's of type int if not hasattr(rhs, 'dtype'): raise ValueError("Symbolic RHS `%s` lacks dtype" % rhs) - if not issubclass(rhs.dtype, np.integer): - raise ValueError("Symbolic RHS `%s` must be of type `int`, found " - "`%s` instead" % (rhs, rhs.dtype)) + + # TODO: Move into a utility function? + is_int_type = isinstance(rhs.dtype, type) and \ + issubclass(rhs.dtype, np.integer) + is_custom_int_type = isinstance(rhs.dtype, CustomIntType) + assert is_int_type or is_custom_int_type, ( + f"Symbolic RHS `{rhs}` must be of type `int`, " + f"found `{rhs.dtype}` instead" + ) + rhs = sympify(rhs) obj = sympy.Expr.__new__(cls, lhs, rhs) @@ -115,6 +122,26 @@ def __mul__(self, other): return super().__mul__(other) +class Mod(sympy.Expr): + # TODO: Add tests + is_Atom = True + is_commutative = True + + def __new__(cls, lhs, rhs, params=None): + rhs = sympify(rhs) + + obj = sympy.Expr.__new__(cls, lhs, rhs) + + obj.lhs = lhs + obj.rhs = rhs + return obj + + def __str__(self): + return "Mod(%s, %s)" % (self.lhs, self.rhs) + + __repr__ = __str__ + + class BasicWrapperMixin: """ @@ -167,7 +194,7 @@ def __new__(cls, call, pointer, params=None, **kwargs): pointer = Symbol(pointer) if isinstance(call, str): call = Symbol(call) - elif not isinstance(call, Basic): + elif not isinstance(call.base, Basic): raise ValueError("`call` must be a `devito.Basic` or a type " "with compatible interface") _params = [] @@ -252,6 +279,10 @@ def __str__(self): def field(self): return self.call + @property + def dtype(self): + return self.field.dtype + __repr__ = __str__ @@ -819,6 +850,10 @@ class VOID(Cast): _base_typ = 'void' +class VOIDP(CastStar): + base = VOID + + class CHARP(CastStar): base = CHAR diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 18e2623764..123a8c46e4 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -295,10 +295,10 @@ def sympy_dtype(expr, base=None): """ Infer the dtype of the expression. """ + # TODO: Edit/fix/update according to PR #2513 dtypes = {base} - {None} - for i in expr.free_symbols: - try: - dtypes.add(i.dtype) - except AttributeError: - pass + for i in expr.args: + dtype = getattr(i, 'dtype', None) + if dtype: + dtypes.add(dtype) return infer_dtype(dtypes) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index b5b564a4d7..7c3309d404 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomIntType'] # *** Custom np.dtypes @@ -123,6 +123,11 @@ def __repr__(self): __str__ = __repr__ +# TODO: Consider if this should be an instance instead of a subclass? +class CustomIntType(CustomDtype): + pass + + # *** np.dtypes lowering @@ -278,6 +283,8 @@ def is_external_ctype(ctype, includes): True if `ctype` is known to be declared in one of the given `includes` files, False otherwise. """ + if isinstance(ctype, CustomDtype): + return False # Get the base type while issubclass(ctype, ctypes._Pointer): ctype = ctype._type_ diff --git a/devito/types/array.py b/devito/types/array.py index 44d7fabf9d..d2be10eb67 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -60,6 +60,21 @@ def shape_allocated(self): def is_const(self): return self._is_const + @property + def _C_free(self): + """ + A symbolic destructor for the Array, injected in the generated code. + + Notes + ----- + To be overridden by subclasses, ignored otherwise. + """ + return None + + @property + def _C_free_priority(self): + return 0 + class Array(ArrayBasic): diff --git a/devito/types/basic.py b/devito/types/basic.py index bd04be8564..4a46643c3d 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -14,7 +14,7 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, CustomDtype) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -84,6 +84,9 @@ def _C_typedata(self): The type of the object in the generated code as a `str`. """ _type = self._C_ctype + if isinstance(_type, CustomDtype): + return ctypes_to_cstr(_type) + while issubclass(_type, _Pointer): _type = _type._type_ @@ -219,6 +222,10 @@ def _mem_shared_remote(self): """ return False + @property + def _restrict_keyword(self): + return 'restrict' + class Basic(CodeSymbol): diff --git a/devito/types/equation.py b/devito/types/equation.py index 3b49213f92..0d33e3debf 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -219,3 +219,7 @@ class ReduceMax(Reduction): class ReduceMin(Reduction): pass + + +class PetscEq(Eq): + pass diff --git a/devito/types/object.py b/devito/types/object.py index cba54b0add..df8f94c171 100644 --- a/devito/types/object.py +++ b/devito/types/object.py @@ -1,14 +1,14 @@ from ctypes import byref - import sympy -from devito.tools import Pickable, as_tuple, sympy_mutex +from devito.tools import Pickable, as_tuple, sympy_mutex, CustomDtype from devito.types.args import ArgProvider from devito.types.caching import Uncached from devito.types.basic import Basic, LocalType from devito.types.utils import CtypesFactory -__all__ = ['Object', 'LocalObject', 'CompositeObject'] + +__all__ = ['Object', 'LocalObject', 'CompositeObject', 'LocalCompositeObject'] class AbstractObject(Basic, sympy.Basic, Pickable): @@ -138,6 +138,7 @@ def __init__(self, name, pname, pfields, value=None): dtype = CtypesFactory.generate(pname, pfields) value = self.__value_setup__(dtype, value) super().__init__(name, dtype, value) + self._pname = pname def __value_setup__(self, dtype, value): return value or byref(dtype._type_()) @@ -148,7 +149,7 @@ def pfields(self): @property def pname(self): - return self.dtype._type_.__name__ + return self._pname @property def fields(self): @@ -231,6 +232,39 @@ def _C_free(self): """ return None + @property + def _C_free_priority(self): + return float('inf') + @property def _mem_global(self): return self._is_global + + +class LocalCompositeObject(CompositeObject, LocalType): + + """ + Object with composite type (e.g., a C struct) defined in C. + """ + + __rargs__ = ('name', 'pname', 'fields') + + def __init__(self, name, pname, fields, modifier=None, liveness='lazy'): + dtype = CustomDtype(f"struct {pname}", modifier=modifier) + Object.__init__(self, name, dtype, None) + self._pname = pname + assert liveness in ['eager', 'lazy'] + self._liveness = liveness + self._fields = fields + + @property + def fields(self): + return self._fields + + @property + def _fields_(self): + return [(i._C_name, i._C_ctype) for i in self.fields] + + @property + def __name__(self): + return self.pname diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 8f036badb9..04da0dcc8d 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -17,7 +17,8 @@ RUN apt-get update && \ # Install for basic base not containing it RUN apt-get install -y wget flex libnuma-dev hwloc curl cmake git \ - autoconf libtool build-essential procps software-properties-common + autoconf libtool build-essential procps software-properties-common \ + gfortran pkgconf libopenblas-serial-dev # Install gcc RUN if [ -n "$gcc" ]; then \ @@ -43,7 +44,6 @@ RUN cd /tmp && mkdir openmpi && \ cd openmpi && ./autogen.pl && \ mkdir build && cd build && \ ../configure --prefix=/opt/openmpi/ \ - --disable-mpi-fortran \ --enable-mca-no-build=btl-uct --enable-mpi1-compatibility && \ make -j ${nproc} && \ make install && \ diff --git a/docker/Dockerfile.devito b/docker/Dockerfile.devito index f167c36e54..ba41309fd9 100644 --- a/docker/Dockerfile.devito +++ b/docker/Dockerfile.devito @@ -4,8 +4,26 @@ # Base image with compilers ARG base=devitocodes/bases:cpu-gcc +ARG petscinstall="" -FROM $base AS builder +FROM $base AS copybase + +################## Install PETSc ############################################ +FROM copybase AS petsccopybase + +RUN apt-get update && apt-get install -y git && \ + python3 -m venv /venv && \ + /venv/bin/pip install --no-cache-dir --upgrade pip && \ + /venv/bin/pip install --no-cache-dir --no-binary numpy numpy && \ + mkdir -p /opt/petsc && \ + cd /opt/petsc && \ + git clone -b release https://gitlab.com/petsc/petsc.git petsc && \ + cd petsc && \ + ./configure --with-fortran-bindings=0 --with-mpi-dir=/opt/openmpi --with-openblas-include=$(pkg-config --variable=includedir openblas) --with-openblas-lib=$(pkg-config --variable=libdir openblas)/libopenblas.so PETSC_ARCH=devito_build && \ + make all + +ARG petscinstall="" +FROM ${petscinstall}copybase AS builder # User/Group Ids ARG USER_ID=1000 @@ -69,6 +87,9 @@ ARG GROUP_ID=1000 ENV HOME=/app ENV APP_HOME=/app +ENV PETSC_ARCH="devito_build" +ENV PETSC_DIR="/opt/petsc/petsc" + # Create the home directory for the new app user. # Create an app user so our program doesn't run as root. # Chown all the files to the app user. @@ -102,4 +123,3 @@ USER app EXPOSE 8888 ENTRYPOINT ["/docker-entrypoint.sh"] CMD ["/jupyter"] - diff --git a/examples/petsc/init_test.py b/examples/petsc/init_test.py new file mode 100644 index 0000000000..462865831f --- /dev/null +++ b/examples/petsc/init_test.py @@ -0,0 +1,8 @@ +import os +from devito.petsc.initialize import PetscInitialize +from devito import configuration +configuration['compiler'] = 'custom' +os.environ['CC'] = 'mpicc' + +PetscInitialize() +print("helloworld") diff --git a/examples/petsc/makefile b/examples/petsc/makefile new file mode 100644 index 0000000000..ca5d2a9f57 --- /dev/null +++ b/examples/petsc/makefile @@ -0,0 +1,5 @@ +-include ${PETSC_DIR}/petscdir.mk +include ${PETSC_DIR}/lib/petsc/conf/variables +include ${PETSC_DIR}/lib/petsc/conf/rules + +all: test diff --git a/examples/petsc/petsc_test.py b/examples/petsc/petsc_test.py new file mode 100644 index 0000000000..69b63c99ec --- /dev/null +++ b/examples/petsc/petsc_test.py @@ -0,0 +1,29 @@ +import os +import numpy as np + +from devito import (Grid, Function, Eq, Operator, configuration) +from devito.petsc import PETScSolve +from devito.petsc.initialize import PetscInitialize +configuration['compiler'] = 'custom' +os.environ['CC'] = 'mpicc' + +PetscInitialize() + + +nx = 81 +ny = 81 + +grid = Grid(shape=(nx, ny), extent=(2., 2.), dtype=np.float64) + +u = Function(name='u', grid=grid, dtype=np.float64, space_order=2) +v = Function(name='v', grid=grid, dtype=np.float64, space_order=2) + +v.data[:] = 5.0 + +eq = Eq(v, u.laplace, subdomain=grid.interior) + +petsc = PETScSolve([eq], u) + +op = Operator(petsc) + +op.apply() diff --git a/examples/petsc/test_init.c b/examples/petsc/test_init.c new file mode 100644 index 0000000000..8d92ce176e --- /dev/null +++ b/examples/petsc/test_init.c @@ -0,0 +1,26 @@ +#include + +extern PetscErrorCode PetscInit(); +extern PetscErrorCode PetscFinal(); + +int main(int argc, char **argv) +{ + PetscInit(argc, argv); + PetscPrintf(PETSC_COMM_WORLD, "Hello World!\n"); + return PetscFinalize(); +} + +PetscErrorCode PetscInit(int argc, char **argv) +{ + static char help[] = "Magic help string\n"; + PetscFunctionBeginUser; + PetscCall(PetscInitialize(&argc, &argv, NULL, help)); + PetscFunctionReturn(0); +} + +PetscErrorCode PetscFinal() +{ + PetscFunctionBeginUser; + PetscCall(PetscFinalize()); + PetscFunctionReturn(0); +} diff --git a/tests/test_iet.py b/tests/test_iet.py index ff963d5518..c1c78eb7ea 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -10,7 +10,8 @@ from devito.ir.iet import (Call, Callable, Conditional, DeviceCall, DummyExpr, Iteration, List, KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations, make_efunc, - retrieve_iteration_tree, Transformer) + retrieve_iteration_tree, Transformer, Callback, + Definition, FindNodes) from devito.ir import SymbolRegistry from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager @@ -128,6 +129,33 @@ def test_find_symbols_nested(mode, expected): assert [f.name for f in found] == eval(expected) +def test_callback_cgen(): + + class FunctionPtr(Callback): + @property + def callback_form(self): + param_types = ', '.join([str(t) for t in + self.param_types]) + return "(%s (*)(%s))%s" % (self.retval, param_types, self.name) + + a = Symbol('a') + b = Symbol('b') + foo0 = Callable('foo0', Definition(a), 'void', parameters=[b]) + foo0_arg = FunctionPtr(foo0.name, foo0.retval, 'int') + code0 = CGen().visit(foo0_arg) + assert str(code0) == '(void (*)(int))foo0' + + # Test nested calls with a Callback as an argument. + call = Call('foo2', [ + Call('foo1', [foo0_arg]) + ]) + code1 = CGen().visit(call) + assert str(code1) == 'foo2(foo1((void (*)(int))foo0));' + + callees = FindNodes(Call).visit(call) + assert len(callees) == 3 + + def test_list_denesting(): l0 = List(header=cgen.Line('a'), body=List(header=cgen.Line('b'))) l1 = l0._rebuild(body=List(header=cgen.Line('c'))) diff --git a/tests/test_petsc.py b/tests/test_petsc.py new file mode 100644 index 0000000000..cfd5c01414 --- /dev/null +++ b/tests/test_petsc.py @@ -0,0 +1,1047 @@ +import numpy as np +import os +import pytest + +from conftest import skipif +from devito import (Grid, Function, TimeFunction, Eq, Operator, switchconfig, + configuration, norm) +from devito.ir.iet import (Call, ElementalFunction, Definition, DummyExpr, + FindNodes, retrieve_iteration_tree) +from devito.types import Constant, LocalCompositeObject +from devito.passes.iet.languages.C import CDataManager +from devito.petsc.types import (DM, Mat, LocalVec, PetscMPIInt, KSP, + PC, KSPConvergedReason, PETScArray, + LinearSolveExpr, FieldData, MultipleFieldData) +from devito.petsc.solve import PETScSolve, separate_eqn, centre_stencil +from devito.petsc.iet.nodes import Expression +from devito.petsc.initialize import PetscInitialize + + +@skipif('petsc') +def test_petsc_initialization(): + # TODO: Temporary workaround until PETSc is automatically + # initialized + configuration['compiler'] = 'custom' + os.environ['CC'] = 'mpicc' + PetscInitialize() + + +@skipif('petsc') +def test_petsc_local_object(): + """ + Test C++ support for PETSc LocalObjects. + """ + lo0 = DM('da', stencil_width=1) + lo1 = Mat('A') + lo2 = LocalVec('x') + lo3 = PetscMPIInt('size') + lo4 = KSP('ksp') + lo5 = PC('pc') + lo6 = KSPConvergedReason('reason') + + iet = Call('foo', [lo0, lo1, lo2, lo3, lo4, lo5, lo6]) + iet = ElementalFunction('foo', iet, parameters=()) + + dm = CDataManager(sregistry=None) + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] + + assert 'DM da;' in str(iet) + assert 'Mat A;' in str(iet) + assert 'Vec x;' in str(iet) + assert 'PetscMPIInt size;' in str(iet) + assert 'KSP ksp;' in str(iet) + assert 'PC pc;' in str(iet) + assert 'KSPConvergedReason reason;' in str(iet) + + +@skipif('petsc') +def test_petsc_functions(): + """ + Test C++ support for PETScArrays. + """ + grid = Grid((2, 2)) + x, y = grid.dimensions + + f0 = Function(name='f', grid=grid, space_order=2, dtype=np.float32) + f1 = Function(name='f', grid=grid, space_order=2, dtype=np.float64) + + ptr0 = PETScArray(name='ptr0', target=f0) + ptr1 = PETScArray(name='ptr1', target=f0, is_const=True) + ptr2 = PETScArray(name='ptr2', target=f1, is_const=True) + + defn0 = Definition(ptr0) + defn1 = Definition(ptr1) + defn2 = Definition(ptr2) + + expr = DummyExpr(ptr0.indexed[x, y], ptr1.indexed[x, y] + 1) + + assert str(defn0) == 'PetscScalar * ptr0_vec;' + assert str(defn1) == 'const PetscScalar * ptr1_vec;' + assert str(defn2) == 'const PetscScalar * ptr2_vec;' + assert str(expr) == 'ptr0[x][y] = ptr1[x][y] + 1;' + + +@skipif('petsc') +def test_petsc_subs(): + """ + Test support for PETScArrays in substitutions. + """ + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=2) + f2 = Function(name='f2', grid=grid, space_order=2) + + arr = PETScArray(name='arr', target=f2) + + eqn = Eq(f1, f2.laplace) + eqn_subs = eqn.subs(f2, arr) + + assert str(eqn) == 'Eq(f1(x, y), Derivative(f2(x, y), (x, 2))' + \ + ' + Derivative(f2(x, y), (y, 2)))' + + assert str(eqn_subs) == 'Eq(f1(x, y), Derivative(arr(x, y), (x, 2))' + \ + ' + Derivative(arr(x, y), (y, 2)))' + + assert str(eqn_subs.rhs.evaluate) == '-2.0*arr(x, y)/h_x**2' + \ + ' + arr(x - h_x, y)/h_x**2 + arr(x + h_x, y)/h_x**2 - 2.0*arr(x, y)/h_y**2' + \ + ' + arr(x, y - h_y)/h_y**2 + arr(x, y + h_y)/h_y**2' + + +@skipif('petsc') +def test_petsc_solve(): + """ + Test PETScSolve. + """ + grid = Grid((2, 2), dtype=np.float64) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f.laplace, g) + + petsc = PETScSolve(eqn, f) + + with switchconfig(openmp=False): + op = Operator(petsc, opt='noop') + + callable_roots = [meta_call.root for meta_call in op._func_table.values()] + + matvec_callback = [root for root in callable_roots if root.name == 'MatMult0'] + + formrhs_callback = [root for root in callable_roots if root.name == 'FormRHS0'] + + action_expr = FindNodes(Expression).visit(matvec_callback[0]) + rhs_expr = FindNodes(Expression).visit(formrhs_callback[0]) + + assert str(action_expr[-1].expr.rhs) == ( + 'x_f[x + 1, y + 2]/ctx0->h_x**2' + ' - 2.0*x_f[x + 2, y + 2]/ctx0->h_x**2' + ' + x_f[x + 3, y + 2]/ctx0->h_x**2' + ' + x_f[x + 2, y + 1]/ctx0->h_y**2' + ' - 2.0*x_f[x + 2, y + 2]/ctx0->h_y**2' + ' + x_f[x + 2, y + 3]/ctx0->h_y**2' + ) + + assert str(rhs_expr[-1].expr.rhs) == 'g[x + 2, y + 2]' + + # Check the iteration bounds are correct. + assert op.arguments().get('x_m') == 0 + assert op.arguments().get('y_m') == 0 + assert op.arguments().get('y_M') == 1 + assert op.arguments().get('x_M') == 1 + + assert len(retrieve_iteration_tree(op)) == 0 + + # TODO: Remove pragmas from PETSc callback functions + assert len(matvec_callback[0].parameters) == 3 + + +@skipif('petsc') +def test_multiple_petsc_solves(): + """ + Test multiple PETScSolves. + """ + grid = Grid((2, 2), dtype=np.float64) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + f2 = Function(name='f2', grid=grid, space_order=2) + g2 = Function(name='g2', grid=grid, space_order=2) + + eqn1 = Eq(f1.laplace, g1) + eqn2 = Eq(f2.laplace, g2) + + petsc1 = PETScSolve(eqn1, f1) + petsc2 = PETScSolve(eqn2, f2) + + with switchconfig(openmp=False): + op = Operator(petsc1+petsc2, opt='noop') + + callable_roots = [meta_call.root for meta_call in op._func_table.values()] + + # One FormRHS, MatShellMult, FormFunction, PopulateMatContext per solve + assert len(callable_roots) == 8 + + +@skipif('petsc') +def test_petsc_cast(): + """ + Test casting of PETScArray. + """ + grid1 = Grid((2), dtype=np.float64) + grid2 = Grid((2, 2), dtype=np.float64) + grid3 = Grid((4, 5, 6), dtype=np.float64) + + f1 = Function(name='f1', grid=grid1, space_order=2) + f2 = Function(name='f2', grid=grid2, space_order=4) + f3 = Function(name='f3', grid=grid3, space_order=6) + + eqn1 = Eq(f1.laplace, 10) + eqn2 = Eq(f2.laplace, 10) + eqn3 = Eq(f3.laplace, 10) + + petsc1 = PETScSolve(eqn1, f1) + petsc2 = PETScSolve(eqn2, f2) + petsc3 = PETScSolve(eqn3, f3) + + with switchconfig(openmp=False): + op1 = Operator(petsc1, opt='noop') + op2 = Operator(petsc2, opt='noop') + op3 = Operator(petsc3, opt='noop') + + cb1 = [meta_call.root for meta_call in op1._func_table.values()] + cb2 = [meta_call.root for meta_call in op2._func_table.values()] + cb3 = [meta_call.root for meta_call in op3._func_table.values()] + + assert 'double (*restrict x_f1) = ' + \ + '(double (*)) x_f1_vec;' in str(cb1[0]) + assert 'double (*restrict x_f2)[info.gxm] = ' + \ + '(double (*)[info.gxm]) x_f2_vec;' in str(cb2[0]) + assert 'double (*restrict x_f3)[info.gym][info.gxm] = ' + \ + '(double (*)[info.gym][info.gxm]) x_f3_vec;' in str(cb3[0]) + + +@skipif('petsc') +def test_LinearSolveExpr(): + + grid = Grid((2, 2), dtype=np.float64) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f, g.laplace) + + linsolveexpr = LinearSolveExpr(eqn.rhs, fielddata=FieldData(target=f)) + + # Check the solver parameters + assert linsolveexpr.solver_parameters == \ + {'ksp_type': 'gmres', 'pc_type': 'jacobi', 'ksp_rtol': 1e-05, + 'ksp_atol': 1e-50, 'ksp_divtol': 100000.0, 'ksp_max_it': 10000} + + +@skipif('petsc') +def test_dmda_create(): + + grid1 = Grid((2), dtype=np.float64) + grid2 = Grid((2, 2), dtype=np.float64) + grid3 = Grid((4, 5, 6), dtype=np.float64) + + f1 = Function(name='f1', grid=grid1, space_order=2) + f2 = Function(name='f2', grid=grid2, space_order=4) + f3 = Function(name='f3', grid=grid3, space_order=6) + + eqn1 = Eq(f1.laplace, 10) + eqn2 = Eq(f2.laplace, 10) + eqn3 = Eq(f3.laplace, 10) + + petsc1 = PETScSolve(eqn1, f1) + petsc2 = PETScSolve(eqn2, f2) + petsc3 = PETScSolve(eqn3, f3) + + with switchconfig(openmp=False): + op1 = Operator(petsc1, opt='noop') + op2 = Operator(petsc2, opt='noop') + op3 = Operator(petsc3, opt='noop') + + assert 'PetscCall(DMDACreate1d(PETSC_COMM_WORLD,DM_BOUNDARY_GHOSTED,' + \ + '2,1,2,NULL,&(da0)));' in str(op1) + + assert 'PetscCall(DMDACreate2d(PETSC_COMM_WORLD,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,2,2,1,1,1,4,NULL,NULL,&(da0)));' \ + in str(op2) + + assert 'PetscCall(DMDACreate3d(PETSC_COMM_WORLD,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,6,5,4' + \ + ',1,1,1,1,6,NULL,NULL,NULL,&(da0)));' in str(op3) + + +@skipif('petsc') +def test_cinterface_petsc_struct(): + + grid = Grid(shape=(11, 11), dtype=np.float64) + f = Function(name='f', grid=grid, space_order=2) + eq = Eq(f.laplace, 10) + petsc = PETScSolve(eq, f) + + name = "foo" + with switchconfig(openmp=False): + op = Operator(petsc, name=name) + + # Trigger the generation of a .c and a .h files + ccode, hcode = op.cinterface(force=True) + + dirname = op._compiler.get_jit_dir() + assert os.path.isfile(os.path.join(dirname, "%s.c" % name)) + assert os.path.isfile(os.path.join(dirname, "%s.h" % name)) + + ccode = str(ccode) + hcode = str(hcode) + + assert 'include "%s.h"' % name in ccode + + # The public `struct UserCtx` only appears in the header file + assert 'struct UserCtx0\n{' not in ccode + assert 'struct UserCtx0\n{' in hcode + + +@skipif('petsc') +@pytest.mark.parametrize('eqn, target, expected', [ + ('Eq(f1.laplace, g1)', + 'f1', ('g1(x, y)', 'Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))')), + ('Eq(g1, f1.laplace)', + 'f1', ('-g1(x, y)', '-Derivative(f1(x, y), (x, 2)) - Derivative(f1(x, y), (y, 2))')), + ('Eq(g1, f1.laplace)', 'g1', + ('Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))', 'g1(x, y)')), + ('Eq(f1 + f1.laplace, g1)', 'f1', ('g1(x, y)', + 'f1(x, y) + Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))')), + ('Eq(g1.dx + f1.dx, g1)', 'f1', + ('g1(x, y) - Derivative(g1(x, y), x)', 'Derivative(f1(x, y), x)')), + ('Eq(g1.dx + f1.dx, g1)', 'g1', + ('-Derivative(f1(x, y), x)', '-g1(x, y) + Derivative(g1(x, y), x)')), + ('Eq(f1 * g1.dx, g1)', 'g1', ('0', 'f1(x, y)*Derivative(g1(x, y), x) - g1(x, y)')), + ('Eq(f1 * g1.dx, g1)', 'f1', ('g1(x, y)', 'f1(x, y)*Derivative(g1(x, y), x)')), + ('Eq((f1 * g1.dx).dy, f1)', 'f1', + ('0', '-f1(x, y) + Derivative(f1(x, y)*Derivative(g1(x, y), x), y)')), + ('Eq((f1 * g1.dx).dy, f1)', 'g1', + ('f1(x, y)', 'Derivative(f1(x, y)*Derivative(g1(x, y), x), y)')), + ('Eq(f2.laplace, g2)', 'g2', + ('-Derivative(f2(t, x, y), (x, 2)) - Derivative(f2(t, x, y), (y, 2))', + '-g2(t, x, y)')), + ('Eq(f2.laplace, g2)', 'f2', ('g2(t, x, y)', + 'Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2))')), + ('Eq(f2.laplace, f2)', 'f2', ('0', + '-f2(t, x, y) + Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2))')), + ('Eq(f2*g2, f2)', 'f2', ('0', 'f2(t, x, y)*g2(t, x, y) - f2(t, x, y)')), + ('Eq(f2*g2, f2)', 'g2', ('f2(t, x, y)', 'f2(t, x, y)*g2(t, x, y)')), + ('Eq(g2*f2.laplace, f2)', 'g2', ('f2(t, x, y)', + '(Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2)))*g2(t, x, y)')), + ('Eq(f2.forward, f2)', 'f2.forward', ('f2(t, x, y)', 'f2(t + dt, x, y)')), + ('Eq(f2.forward, f2)', 'f2', ('-f2(t + dt, x, y)', '-f2(t, x, y)')), + ('Eq(f2.forward.laplace, f2)', 'f2.forward', ('f2(t, x, y)', + 'Derivative(f2(t + dt, x, y), (x, 2)) + Derivative(f2(t + dt, x, y), (y, 2))')), + ('Eq(f2.forward.laplace, f2)', 'f2', + ('-Derivative(f2(t + dt, x, y), (x, 2)) - Derivative(f2(t + dt, x, y), (y, 2))', + '-f2(t, x, y)')), + ('Eq(f2.laplace + f2.forward.laplace, g2)', 'f2.forward', + ('g2(t, x, y) - Derivative(f2(t, x, y), (x, 2)) - Derivative(f2(t, x, y), (y, 2))', + 'Derivative(f2(t + dt, x, y), (x, 2)) + Derivative(f2(t + dt, x, y), (y, 2))')), + ('Eq(g2.laplace, f2 + g2.forward)', 'g2.forward', + ('f2(t, x, y) - Derivative(g2(t, x, y), (x, 2)) - Derivative(g2(t, x, y), (y, 2))', + '-g2(t + dt, x, y)')) +]) +def test_separate_eqn(eqn, target, expected): + """ + Test the separate_eqn function. + + This function is called within PETScSolve to decompose the equation + into the form F(x) = b. This is necessary to utilise the SNES + interface in PETSc. + """ + grid = Grid((2, 2)) + + so = 2 + + f1 = Function(name='f1', grid=grid, space_order=so) # noqa + g1 = Function(name='g1', grid=grid, space_order=so) # noqa + + f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa + g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa + + b, F, _ = separate_eqn(eval(eqn), eval(target)) + expected_b, expected_F = expected + + assert str(b) == expected_b + assert str(F) == expected_F + + +@skipif('petsc') +@pytest.mark.parametrize('eqn, target, expected', [ + ('Eq(f1.laplace, g1).evaluate', 'f1', + ( + 'g1(x, y)', + '-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2 ' + '- 2.0*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2' + )), + ('Eq(g1, f1.laplace).evaluate', 'f1', + ( + '-g1(x, y)', + '-(-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2) ' + '- (-2.0*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2)' + )), + ('Eq(g1, f1.laplace).evaluate', 'g1', + ( + '-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2 ' + '- 2.0*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2', + 'g1(x, y)' + )), + ('Eq(f1 + f1.laplace, g1).evaluate', 'f1', + ( + 'g1(x, y)', + '-2.0*f1(x, y)/h_x**2 + f1(x - h_x, y)/h_x**2 + f1(x + h_x, y)/h_x**2 - 2.0' + '*f1(x, y)/h_y**2 + f1(x, y - h_y)/h_y**2 + f1(x, y + h_y)/h_y**2 + f1(x, y)' + )), + ('Eq(g1.dx + f1.dx, g1).evaluate', 'f1', + ( + '-(-g1(x, y)/h_x + g1(x + h_x, y)/h_x) + g1(x, y)', + '-f1(x, y)/h_x + f1(x + h_x, y)/h_x' + )), + ('Eq(g1.dx + f1.dx, g1).evaluate', 'g1', + ( + '-(-f1(x, y)/h_x + f1(x + h_x, y)/h_x)', + '-g1(x, y)/h_x + g1(x + h_x, y)/h_x - g1(x, y)' + )), + ('Eq(f1 * g1.dx, g1).evaluate', 'g1', + ( + '0', '(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y) - g1(x, y)' + )), + ('Eq(f1 * g1.dx, g1).evaluate', 'f1', + ( + 'g1(x, y)', '(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y)' + )), + ('Eq((f1 * g1.dx).dy, f1).evaluate', 'f1', + ( + '0', '(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y) ' + '+ (-g1(x, y + h_y)/h_x + g1(x + h_x, y + h_y)/h_x)*f1(x, y + h_y)/h_y ' + '- f1(x, y)' + )), + ('Eq((f1 * g1.dx).dy, f1).evaluate', 'g1', + ( + 'f1(x, y)', '(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y) + ' + '(-g1(x, y + h_y)/h_x + g1(x + h_x, y + h_y)/h_x)*f1(x, y + h_y)/h_y' + )), + ('Eq(f2.laplace, g2).evaluate', 'g2', + ( + '-(-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + f2(t, x + h_x, y)' + '/h_x**2) - (-2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2 + ' + 'f2(t, x, y + h_y)/h_y**2)', '-g2(t, x, y)' + )), + ('Eq(f2.laplace, g2).evaluate', 'f2', + ( + 'g2(t, x, y)', '-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + ' + 'f2(t, x + h_x, y)/h_x**2 - 2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)' + '/h_y**2 + f2(t, x, y + h_y)/h_y**2' + )), + ('Eq(f2.laplace, f2).evaluate', 'f2', + ( + '0', '-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + ' + 'f2(t, x + h_x, y)/h_x**2 - 2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2' + ' + f2(t, x, y + h_y)/h_y**2 - f2(t, x, y)' + )), + ('Eq(g2*f2.laplace, f2).evaluate', 'g2', + ( + 'f2(t, x, y)', '(-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + ' + 'f2(t, x + h_x, y)/h_x**2 - 2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2' + ' + f2(t, x, y + h_y)/h_y**2)*g2(t, x, y)' + )), + ('Eq(f2.forward.laplace, f2).evaluate', 'f2.forward', + ( + 'f2(t, x, y)', '-2.0*f2(t + dt, x, y)/h_x**2 + f2(t + dt, x - h_x, y)/h_x**2' + ' + f2(t + dt, x + h_x, y)/h_x**2 - 2.0*f2(t + dt, x, y)/h_y**2 + ' + 'f2(t + dt, x, y - h_y)/h_y**2 + f2(t + dt, x, y + h_y)/h_y**2' + )), + ('Eq(f2.forward.laplace, f2).evaluate', 'f2', + ( + '-(-2.0*f2(t + dt, x, y)/h_x**2 + f2(t + dt, x - h_x, y)/h_x**2 + ' + 'f2(t + dt, x + h_x, y)/h_x**2) - (-2.0*f2(t + dt, x, y)/h_y**2 + ' + 'f2(t + dt, x, y - h_y)/h_y**2 + f2(t + dt, x, y + h_y)/h_y**2)', + '-f2(t, x, y)' + )), + ('Eq(f2.laplace + f2.forward.laplace, g2).evaluate', 'f2.forward', + ( + '-(-2.0*f2(t, x, y)/h_x**2 + f2(t, x - h_x, y)/h_x**2 + f2(t, x + h_x, y)/' + 'h_x**2) - (-2.0*f2(t, x, y)/h_y**2 + f2(t, x, y - h_y)/h_y**2 + ' + 'f2(t, x, y + h_y)/h_y**2) + g2(t, x, y)', '-2.0*f2(t + dt, x, y)/h_x**2 + ' + 'f2(t + dt, x - h_x, y)/h_x**2 + f2(t + dt, x + h_x, y)/h_x**2 - 2.0*' + 'f2(t + dt, x, y)/h_y**2 + f2(t + dt, x, y - h_y)/h_y**2 + ' + 'f2(t + dt, x, y + h_y)/h_y**2' + )), + ('Eq(g2.laplace, f2 + g2.forward).evaluate', 'g2.forward', + ( + '-(-2.0*g2(t, x, y)/h_x**2 + g2(t, x - h_x, y)/h_x**2 + ' + 'g2(t, x + h_x, y)/h_x**2) - (-2.0*g2(t, x, y)/h_y**2 + g2(t, x, y - h_y)' + '/h_y**2 + g2(t, x, y + h_y)/h_y**2) + f2(t, x, y)', '-g2(t + dt, x, y)' + )) +]) +def test_separate_eval_eqn(eqn, target, expected): + """ + Test the separate_eqn function on pre-evaluated equations. + This ensures that evaluated equations can be passed to PETScSolve, + allowing users to modify stencils for specific boundary conditions, + such as implementing free surface boundary conditions. + """ + grid = Grid((2, 2)) + + so = 2 + + f1 = Function(name='f1', grid=grid, space_order=so) # noqa + g1 = Function(name='g1', grid=grid, space_order=so) # noqa + + f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa + g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa + + b, F, _ = separate_eqn(eval(eqn), eval(target)) + expected_b, expected_F = expected + + assert str(b) == expected_b + assert str(F) == expected_F + + +@skipif('petsc') +@pytest.mark.parametrize('expr, so, target, expected', [ + ('f1.laplace', 2, 'f1', '-2.0*f1(x, y)/h_y**2 - 2.0*f1(x, y)/h_x**2'), + ('f1 + f1.laplace', 2, 'f1', + 'f1(x, y) - 2.0*f1(x, y)/h_y**2 - 2.0*f1(x, y)/h_x**2'), + ('g1.dx + f1.dx', 2, 'f1', '-f1(x, y)/h_x'), + ('10 + f1.dx2', 2, 'g1', '0'), + ('(f1 * g1.dx).dy', 2, 'f1', + '(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y)'), + ('(f1 * g1.dx).dy', 2, 'g1', '-(-1/h_y)*f1(x, y)*g1(x, y)/h_x'), + ('f2.laplace', 2, 'f2', '-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2'), + ('f2*g2', 2, 'f2', 'f2(t, x, y)*g2(t, x, y)'), + ('g2*f2.laplace', 2, 'f2', + '(-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2)*g2(t, x, y)'), + ('f2.forward', 2, 'f2.forward', 'f2(t + dt, x, y)'), + ('f2.forward.laplace', 2, 'f2.forward', + '-2.0*f2(t + dt, x, y)/h_y**2 - 2.0*f2(t + dt, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 2, 'f2.forward', + '-2.0*f2(t + dt, x, y)/h_y**2 - 2.0*f2(t + dt, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 2, + 'f2', '-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2'), + ('f2.laplace', 4, 'f2', '-2.5*f2(t, x, y)/h_y**2 - 2.5*f2(t, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 4, 'f2.forward', + '-2.5*f2(t + dt, x, y)/h_y**2 - 2.5*f2(t + dt, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 4, 'f2', + '-2.5*f2(t, x, y)/h_y**2 - 2.5*f2(t, x, y)/h_x**2'), + ('f2.forward*f2.forward.laplace', 4, 'f2.forward', + '(-2.5*f2(t + dt, x, y)/h_y**2 - 2.5*f2(t + dt, x, y)/h_x**2)*f2(t + dt, x, y)') +]) +def test_centre_stencil(expr, so, target, expected): + """ + Test extraction of centre stencil from an equation. + """ + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=so) # noqa + g1 = Function(name='g1', grid=grid, space_order=so) # noqa + + f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa + g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa + + centre = centre_stencil(eval(expr), eval(target)) + + assert str(centre) == expected + + +@skipif('petsc') +def test_callback_arguments(): + """ + Test the arguments of each callback function. + """ + grid = Grid((2, 2), dtype=np.float64) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + eqn1 = Eq(f1.laplace, g1) + + petsc1 = PETScSolve(eqn1, f1) + + with switchconfig(openmp=False): + op = Operator(petsc1) + + mv = op._func_table['MatMult0'].root + ff = op._func_table['FormFunction0'].root + + assert len(mv.parameters) == 3 + assert len(ff.parameters) == 4 + + assert str(mv.parameters) == '(J, X, Y)' + assert str(ff.parameters) == '(snes, X, F, dummy)' + + +@skipif('petsc') +def test_petsc_struct(): + + grid = Grid((2, 2), dtype=np.float64) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + mu1 = Constant(name='mu1', value=2.0) + mu2 = Constant(name='mu2', value=2.0) + + eqn1 = Eq(f1.laplace, g1*mu1) + petsc1 = PETScSolve(eqn1, f1) + + eqn2 = Eq(f1, g1*mu2) + + with switchconfig(openmp=False): + op = Operator([eqn2] + petsc1) + + arguments = op.arguments() + + # Check mu1 and mu2 in arguments + assert 'mu1' in arguments + assert 'mu2' in arguments + + # Check mu1 and mu2 in op.parameters + assert mu1 in op.parameters + assert mu2 in op.parameters + + # Check PETSc struct not in op.parameters + assert all(not isinstance(i, LocalCompositeObject) for i in op.parameters) + + +@skipif('petsc') +def test_apply(): + + grid = Grid(shape=(13, 13), dtype=np.float64) + + pn = Function(name='pn', grid=grid, space_order=2) + rhs = Function(name='rhs', grid=grid, space_order=2) + mu = Constant(name='mu', value=2.0) + + eqn = Eq(pn.laplace*mu, rhs, subdomain=grid.interior) + + petsc = PETScSolve(eqn, pn) + + # Build the op + op = Operator(petsc) + + # Check the Operator runs without errors + op.apply() + + # Verify that users can override `mu` + mu_new = Constant(name='mu_new', value=4.0) + op.apply(mu=mu_new) + + +@skipif('petsc') +def test_petsc_frees(): + + grid = Grid((2, 2), dtype=np.float64) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f.laplace, g) + petsc = PETScSolve(eqn, f) + + with switchconfig(openmp=False): + op = Operator(petsc) + + frees = op.body.frees + + # Check the frees appear in the following order + assert str(frees[0]) == 'PetscCall(VecDestroy(&(bglobal0)));' + assert str(frees[1]) == 'PetscCall(VecDestroy(&(xglobal0)));' + assert str(frees[2]) == 'PetscCall(MatDestroy(&(J0)));' + assert str(frees[3]) == 'PetscCall(SNESDestroy(&(snes0)));' + assert str(frees[4]) == 'PetscCall(DMDestroy(&(da0)));' + + +@skipif('petsc') +def test_calls_to_callbacks(): + + grid = Grid((2, 2), dtype=np.float64) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f.laplace, g) + petsc = PETScSolve(eqn, f) + + with switchconfig(openmp=False): + op = Operator(petsc) + + ccode = str(op.ccode) + + assert '(void (*)(void))MatMult0' in ccode + assert 'PetscCall(SNESSetFunction(snes0,NULL,FormFunction0,(void*)(da0)));' in ccode + + +@skipif('petsc') +def test_start_ptr(): + """ + Verify that a pointer to the start of the memory address is correctly + generated for TimeFunction objects. This pointer should indicate the + beginning of the multidimensional array that will be overwritten at + the current time step. + This functionality is crucial for VecReplaceArray operations, as it ensures + that the correct memory location is accessed and modified during each time step. + """ + grid = Grid((11, 11), dtype=np.float64) + u1 = TimeFunction(name='u1', grid=grid, space_order=2) + eq1 = Eq(u1.dt, u1.laplace, subdomain=grid.interior) + petsc1 = PETScSolve(eq1, u1.forward) + + with switchconfig(openmp=False): + op1 = Operator(petsc1) + + # Verify the case with modulo time stepping + assert 'double * u1_ptr0 = t1*localsize0 + (double*)(u1_vec->data);' in str(op1) + + # Verify the case with no modulo time stepping + u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) + eq2 = Eq(u2.dt, u2.laplace, subdomain=grid.interior) + petsc2 = PETScSolve(eq2, u2.forward) + + with switchconfig(openmp=False): + op2 = Operator(petsc2) + + assert 'double * u2_ptr0 = (time + 1)*localsize0 + ' + \ + '(double*)(u2_vec->data);' in str(op2) + + +@skipif('petsc') +def test_time_loop(): + """ + Verify the following: + - Modulo dimensions are correctly assigned and updated in the PETSc struct + at each time step. + - Only assign/update the modulo dimensions required by any of the + PETSc callback functions. + """ + grid = Grid((11, 11), dtype=np.float64) + + # Modulo time stepping + u1 = TimeFunction(name='u1', grid=grid, space_order=2) + v1 = Function(name='v1', grid=grid, space_order=2) + eq1 = Eq(v1.laplace, u1) + petsc1 = PETScSolve(eq1, v1) + with switchconfig(openmp=False): + op1 = Operator(petsc1) + body1 = str(op1.body) + rhs1 = str(op1._func_table['FormRHS0'].root.ccode) + + assert 'ctx0.t0 = t0' in body1 + assert 'ctx0.t1 = t1' not in body1 + assert 'ctx0->t0' in rhs1 + assert 'ctx0->t1' not in rhs1 + + # Non-modulo time stepping + u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) + v2 = Function(name='v2', grid=grid, space_order=2, save=5) + eq2 = Eq(v2.laplace, u2) + petsc2 = PETScSolve(eq2, v2) + with switchconfig(openmp=False): + op2 = Operator(petsc2) + body2 = str(op2.body) + rhs2 = str(op2._func_table['FormRHS0'].root.ccode) + + assert 'ctx0.time = time' in body2 + assert 'ctx0->time' in rhs2 + + # Modulo time stepping with more than one time step + # used in one of the callback functions + eq3 = Eq(v1.laplace, u1 + u1.forward) + petsc3 = PETScSolve(eq3, v1) + with switchconfig(openmp=False): + op3 = Operator(petsc3) + body3 = str(op3.body) + rhs3 = str(op3._func_table['FormRHS0'].root.ccode) + + assert 'ctx0.t0 = t0' in body3 + assert 'ctx0.t1 = t1' in body3 + assert 'ctx0->t0' in rhs3 + assert 'ctx0->t1' in rhs3 + + # Multiple petsc solves within the same time loop + v2 = Function(name='v2', grid=grid, space_order=2) + eq4 = Eq(v1.laplace, u1) + petsc4 = PETScSolve(eq4, v1) + eq5 = Eq(v2.laplace, u1) + petsc5 = PETScSolve(eq5, v2) + with switchconfig(openmp=False): + op4 = Operator(petsc4 + petsc5) + body4 = str(op4.body) + + assert 'ctx0.t0 = t0' in body4 + assert body4.count('ctx0.t0 = t0') == 1 + + +@skipif('petsc') +def test_solve_output(): + """ + Verify that PETScSolve returns the correct output for + simple cases e.g with the identity matrix. + """ + grid = Grid(shape=(11, 11), dtype=np.float64) + + u = Function(name='u', grid=grid, space_order=2) + v = Function(name='v', grid=grid, space_order=2) + + # Solving Ax=b where A is the identity matrix + v.data[:] = 5.0 + eqn = Eq(u, v) + petsc = PETScSolve(eqn, target=u) + + with switchconfig(openmp=False): + op = Operator(petsc) + # Check the solve function returns the correct output + op.apply() + + assert np.allclose(u.data, v.data) + + +class TestCoupledLinear: + # The coupled interface can be used even for uncoupled problems, meaning + # the equations will be solved within a single matrix system. + # These tests use simple problems to validate functionality, but they help + # ensure correctness in code generation. + # TODO: Add more comprehensive tests for fully coupled problems. + # TODO: Add subdomain tests, time loop, multiple coupled etc. + + @skipif('petsc') + def test_coupled_vs_non_coupled(self): + grid = Grid(shape=(11, 11), dtype=np.float64) + + functions = [Function(name=n, grid=grid, space_order=2) + for n in ['e', 'f', 'g', 'h']] + e, f, g, h = functions + + f.data[:] = 5. + h.data[:] = 5. + + eq1 = Eq(e.laplace, f) + eq2 = Eq(g.laplace, h) + + # Non-coupled + petsc1 = PETScSolve(eq1, target=e) + petsc2 = PETScSolve(eq2, target=g) + + with switchconfig(openmp=False): + op1 = Operator(petsc1 + petsc2, opt='noop') + op1.apply() + + enorm1 = norm(e) + gnorm1 = norm(g) + + # Reset + e.data[:] = 0 + g.data[:] = 0 + + # Coupled + # TODO: Need more friendly API for coupled - just + # using a dict for now + petsc3 = PETScSolve({e: [eq1], g: [eq2]}) + with switchconfig(openmp=False): + op2 = Operator(petsc3, opt='noop') + op2.apply() + + enorm2 = norm(e) + gnorm2 = norm(g) + + print('enorm1:', enorm1) + print('enorm2:', enorm2) + assert np.isclose(enorm1, enorm2, rtol=1e-16) + assert np.isclose(gnorm1, gnorm2, rtol=1e-16) + + callbacks1 = [meta_call.root for meta_call in op1._func_table.values()] + callbacks2 = [meta_call.root for meta_call in op2._func_table.values()] + + # Solving for multiple fields within the same matrix system requires + # additional machinery and more callback functions + assert len(callbacks1) == 8 + assert len(callbacks2) == 11 + + # Check fielddata type + fielddata1 = petsc1[0].rhs.fielddata + fielddata2 = petsc2[0].rhs.fielddata + fielddata3 = petsc3[0].rhs.fielddata + + assert isinstance(fielddata1, FieldData) + assert isinstance(fielddata2, FieldData) + assert isinstance(fielddata3, MultipleFieldData) + + @skipif('petsc') + def test_coupled_structs(self): + grid = Grid(shape=(11, 11), dtype=np.float64) + + functions = [Function(name=n, grid=grid, space_order=2) + for n in ['e', 'f', 'g', 'h']] + e, f, g, h = functions + + eq1 = Eq(e + 5, f) + eq2 = Eq(g + 10, h) + + petsc = PETScSolve({f: [eq1], h: [eq2]}) + + name = "foo" + with switchconfig(openmp=False): + op = Operator(petsc, name=name) + + # Trigger the generation of a .c and a .h files + ccode, hcode = op.cinterface(force=True) + + dirname = op._compiler.get_jit_dir() + assert os.path.isfile(os.path.join(dirname, f"{name}.c")) + assert os.path.isfile(os.path.join(dirname, f"{name}.h")) + + ccode = str(ccode) + hcode = str(hcode) + + assert f'include "{name}.h"' in ccode + + # The public `struct JacobianCtx` only appears in the header file + assert 'struct JacobianCtx\n{' not in ccode + assert 'struct JacobianCtx\n{' in hcode + + # The public `struct SubMatrixCtx` only appears in the header file + assert 'struct SubMatrixCtx\n{' not in ccode + assert 'struct SubMatrixCtx\n{' in hcode + + # The public `struct UserCtx0` only appears in the header file + assert 'struct UserCtx0\n{' not in ccode + assert 'struct UserCtx0\n{' in hcode + + @skipif('petsc') + def test_coupled_frees(self): + grid = Grid(shape=(11, 11), dtype=np.float64) + + functions = [Function(name=n, grid=grid, space_order=2) + for n in ['e', 'f', 'g', 'h']] + e, f, g, h = functions + + eq1 = Eq(e.laplace, h) + eq2 = Eq(f.laplace, h) + eq3 = Eq(g.laplace, h) + + petsc1 = PETScSolve({e: [eq1], f: [eq2]}) + petsc2 = PETScSolve({e: [eq1], f: [eq2], g: [eq3]}) + + with switchconfig(openmp=False): + op1 = Operator(petsc1, opt='noop') + op2 = Operator(petsc2, opt='noop') + + frees1 = op1.body.frees + frees2 = op2.body.frees + + # Check solver with two fields + # IS destroys + assert str(frees1[0]) == 'PetscCall(ISDestroy(&(fields0[0])));' + assert str(frees1[1]) == 'PetscCall(ISDestroy(&(fields0[1])));' + assert str(frees1[2]) == 'PetscCall(PetscFree(fields0));' + # Sub DM destroys + assert str(frees1[3]) == 'PetscCall(DMDestroy(&(subdms0[0])));' + assert str(frees1[4]) == 'PetscCall(DMDestroy(&(subdms0[1])));' + assert str(frees1[5]) == 'PetscCall(PetscFree(subdms0));' + + # Check solver with three fields + # IS destroys + assert str(frees2[0]) == 'PetscCall(ISDestroy(&(fields0[0])));' + assert str(frees2[1]) == 'PetscCall(ISDestroy(&(fields0[1])));' + assert str(frees2[2]) == 'PetscCall(ISDestroy(&(fields0[2])));' + assert str(frees2[3]) == 'PetscCall(PetscFree(fields0));' + # Sub DM destroys + assert str(frees2[4]) == 'PetscCall(DMDestroy(&(subdms0[0])));' + assert str(frees2[5]) == 'PetscCall(DMDestroy(&(subdms0[1])));' + assert str(frees2[6]) == 'PetscCall(DMDestroy(&(subdms0[2])));' + assert str(frees2[7]) == 'PetscCall(PetscFree(subdms0));' + + @skipif('petsc') + def test_dmda_dofs(self): + grid = Grid(shape=(11, 11), dtype=np.float64) + + functions = [Function(name=n, grid=grid, space_order=2) + for n in ['e', 'f', 'g', 'h']] + e, f, g, h = functions + + eq1 = Eq(e.laplace, h) + eq2 = Eq(f.laplace, h) + eq3 = Eq(g.laplace, h) + + petsc1 = PETScSolve({e: [eq1]}) + petsc2 = PETScSolve({e: [eq1], f: [eq2]}) + petsc3 = PETScSolve({e: [eq1], f: [eq2], g: [eq3]}) + + with switchconfig(openmp=False): + op1 = Operator(petsc1, opt='noop') + op2 = Operator(petsc2, opt='noop') + op3 = Operator(petsc3, opt='noop') + + # Check the number of dofs in the DMDA for each field + assert 'PetscCall(DMDACreate2d(PETSC_COMM_WORLD,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,11,11,1,1,1,2,NULL,NULL,&(da0)));' \ + in str(op1) + + assert 'PetscCall(DMDACreate2d(PETSC_COMM_WORLD,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,11,11,1,1,2,2,NULL,NULL,&(da0)));' \ + in str(op2) + + assert 'PetscCall(DMDACreate2d(PETSC_COMM_WORLD,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,11,11,1,1,3,2,NULL,NULL,&(da0)));' \ + in str(op3) + + @skipif('petsc') + def test_submatrices(self): + grid = Grid(shape=(11, 11), dtype=np.float64) + + functions = [Function(name=n, grid=grid, space_order=2) + for n in ['e', 'f', 'g', 'h']] + e, f, g, h = functions + + eq1 = Eq(e.laplace, f) + eq2 = Eq(g.laplace, h) + + petsc = PETScSolve({e: [eq1], g: [eq2]}) + + submatrices = petsc[0].rhs.fielddata.submatrices + + j00 = submatrices.get_submatrix(e, 'J00') + j01 = submatrices.get_submatrix(e, 'J01') + j10 = submatrices.get_submatrix(g, 'J10') + j11 = submatrices.get_submatrix(g, 'J11') + + # Check the number of submatrices + assert len(submatrices.submatrix_keys) == 4 + assert str(submatrices.submatrix_keys) == "['J00', 'J01', 'J10', 'J11']" + + # Technically a non-coupled problem, so the only non-zero submatrices + # should be the diagonal ones i.e J00 and J11 + assert submatrices.nonzero_submatrix_keys == ['J00', 'J11'] + assert submatrices.get_submatrix(e, 'J01')['matvecs'] is None + assert submatrices.get_submatrix(g, 'J10')['matvecs'] is None + + j00 = submatrices.get_submatrix(e, 'J00') + j11 = submatrices.get_submatrix(g, 'J11') + + assert str(j00['matvecs'][0]) == 'Eq(y_e(x, y),' \ + + ' Derivative(x_e(x, y), (x, 2)) + Derivative(x_e(x, y), (y, 2)))' + + assert str(j11['matvecs'][0]) == 'Eq(y_g(x, y),' \ + + ' Derivative(x_g(x, y), (x, 2)) + Derivative(x_g(x, y), (y, 2)))' + + # Check the derivative wrt fields + assert j00['derivative_wrt'] == e + assert j01['derivative_wrt'] == g + assert j10['derivative_wrt'] == e + assert j11['derivative_wrt'] == g + + # TODO: + # @skipif('petsc') + # def test_create_submats(self): + + # add tests for all new callbacks + # def test_create_whole_matmult(): diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 2bae5679c8..febec6a25e 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -14,10 +14,11 @@ CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, ccode, uxreplace, - retrieve_derivatives) + retrieve_derivatives, sympy_dtype) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, - ComponentAccess, StencilDimension, Symbol as dSymbol) + ComponentAccess, StencilDimension, Symbol as dSymbol, + CompositeObject) from devito.types.basic import AbstractSymbol @@ -249,6 +250,17 @@ def test_field_from_pointer(): # Free symbols assert ffp1.free_symbols == {s} + # Test dtype + f = dSymbol('f') + pfields = [(f._C_name, f._C_ctype)] + struct = CompositeObject('s1', 'myStruct', pfields) + ffp4 = FieldFromPointer(f, struct) + assert str(ffp4) == 's1->f' + assert ffp4.dtype == f.dtype + expr = 1/ffp4 + dtype = sympy_dtype(expr) + assert dtype == f.dtype + def test_field_from_composite(): s = Symbol('s') @@ -293,7 +305,8 @@ def test_extended_sympy_arithmetic(): # noncommutative o = Object(name='o', dtype=c_void_p) bar = FieldFromPointer('bar', o) - assert ccode(-1 + bar) == '-1 + o->bar' + # TODO: Edit/fix/update according to PR #2513 + assert ccode(-1 + bar) == 'o->bar - 1' def test_integer_abs():