diff --git a/.github/workflows/docker-petsc.yml b/.github/workflows/docker-petsc.yml new file mode 100644 index 0000000000..b0058423ce --- /dev/null +++ b/.github/workflows/docker-petsc.yml @@ -0,0 +1,56 @@ +name: Build CPU base for PETSc + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + pull_request: + branches: + - master + workflow_dispatch: + schedule: + # Run once a week + - cron: "0 13 * * 1" + +jobs: +####################################################### +############## Basic gcc CPU ########################## +####################################################### + deploy-cpu-bases: + name: "cpu-base" + runs-on: ubuntu-latest + env: + DOCKER_BUILDKIT: "1" + + steps: + - name: Checkout devito + uses: actions/checkout@v4 + + - name: Check event name + run: echo ${{ github.event_name }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v2 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: cleanup + run: docker system prune -a -f + + - name: GCC image + uses: docker/build-push-action@v5 + with: + context: . 
+ file: './docker/Dockerfile.cpu' + push: true + target: 'gcc' + build-args: 'arch=gcc' + tags: 'zoeleibowitz/bases:cpu-gcc' diff --git a/.github/workflows/pytest-petsc.yml b/.github/workflows/pytest-petsc.yml new file mode 100644 index 0000000000..6f0a729e95 --- /dev/null +++ b/.github/workflows/pytest-petsc.yml @@ -0,0 +1,74 @@ +name: CI-petsc + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + # Trigger the workflow on push or pull request, + # but only for the master branch + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + pytest: + name: ${{ matrix.name }}-${{ matrix.set }} + runs-on: "${{ matrix.os }}" + + env: + DOCKER_BUILDKIT: "1" + DEVITO_ARCH: "${{ matrix.arch }}" + DEVITO_LANGUAGE: ${{ matrix.language }} + + strategy: + # Prevent all builds from stopping if a single one fails + fail-fast: false + + matrix: + name: [ + pytest-docker-py39-gcc-noomp + ] + include: + - name: pytest-docker-py39-gcc-noomp + python-version: '3.9' + os: ubuntu-latest + arch: "gcc" + language: "C" + sympy: "1.12" + + steps: + - name: Checkout devito + uses: actions/checkout@v4 + + - name: Build docker image + run: | + docker build . 
--file docker/Dockerfile.devito --tag devito_img --build-arg base=zoeleibowitz/bases:cpu-${{ matrix.arch }} --build-arg petscinstall=petsc + + - name: Set run prefix + run: | + echo "RUN_CMD=docker run --rm -t -e CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }} --name testrun devito_img" >> $GITHUB_ENV + id: set-run + + - name: Set tests + run : | + echo "TESTS=tests/test_petsc.py" >> $GITHUB_ENV + id: set-tests + + - name: Check configuration + run: | + ${{ env.RUN_CMD }} python3 -c "from devito import configuration; print(''.join(['%s: %s \n' % (k, v) for (k, v) in configuration.items()]))" + + - name: Test with pytest + run: | + ${{ env.RUN_CMD }} pytest --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }} + + - name: Upload coverage to Codecov + if: "!contains(matrix.name, 'docker')" + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + name: ${{ matrix.name }} diff --git a/conftest.py b/conftest.py index 4bb0629327..b804aaebc7 100644 --- a/conftest.py +++ b/conftest.py @@ -14,6 +14,7 @@ from devito.ir.iet import (FindNodes, FindSymbols, Iteration, ParallelBlock, retrieve_iteration_tree) from devito.tools import as_tuple +from devito.petsc.utils import get_petsc_dir, get_petsc_arch try: from mpi4py import MPI # noqa @@ -33,7 +34,7 @@ def skipif(items, whole_module=False): accepted = set() accepted.update({'device', 'device-C', 'device-openmp', 'device-openacc', 'device-aomp', 'cpu64-icc', 'cpu64-icx', 'cpu64-nvc', 'cpu64-arm', - 'cpu64-icpx', 'chkpnt'}) + 'cpu64-icpx', 'chkpnt', 'petsc'}) accepted.update({'nodevice'}) unknown = sorted(set(items) - accepted) if unknown: @@ -87,6 +88,19 @@ def skipif(items, whole_module=False): if i == 'chkpnt' and Revolver is NoopRevolver: skipit = "pyrevolve not installed" break + if i == 'petsc': + petsc_dir = get_petsc_dir() + petsc_arch = get_petsc_arch() + if petsc_dir is None or petsc_arch is None: + skipit = "PETSC_DIR or PETSC_ARCH are not set" + break + else: + petsc_installed = 
os.path.join( + petsc_dir, petsc_arch, 'include', 'petscconf.h' + ) + if not os.path.isfile(petsc_installed): + skipit = "PETSc is not installed" + break if skipit is False: return pytest.mark.skipif(False, reason='') diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index ada1c23f22..67ed9269a4 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -9,10 +9,12 @@ Stencil, detect_io, detect_accesses) from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict -from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min +from devito.types import (Eq, Inc, ReduceMax, ReduceMin, + relational_min) +from devito.types.equation import InjectSolveEq __all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax', - 'identity_mapper'] + 'identity_mapper', 'OpInjectSolve'] class IREq(sympy.Eq, Pickable): @@ -102,7 +104,8 @@ def detect(cls, expr): reduction_mapper = { Inc: OpInc, ReduceMax: OpMax, - ReduceMin: OpMin + ReduceMin: OpMin, + InjectSolveEq: OpInjectSolve } try: return reduction_mapper[type(expr)] @@ -119,6 +122,7 @@ def detect(cls, expr): OpInc = Operation('+') OpMax = Operation('max') OpMin = Operation('min') +OpInjectSolve = Operation('solve') identity_mapper = { diff --git a/devito/ir/iet/algorithms.py b/devito/ir/iet/algorithms.py index 0b57b876f7..9d2db185db 100644 --- a/devito/ir/iet/algorithms.py +++ b/devito/ir/iet/algorithms.py @@ -3,6 +3,8 @@ from devito.ir.iet import (Expression, Increment, Iteration, List, Conditional, SyncSpot, Section, HaloSpot, ExpressionBundle) from devito.tools import timed_pass +from devito.petsc.types import LinearSolveExpr +from devito.petsc.iet.utils import petsc_iet_mapper __all__ = ['iet_build'] @@ -24,6 +26,8 @@ def iet_build(stree): for e in i.exprs: if e.is_Increment: exprs.append(Increment(e)) + elif isinstance(e.rhs, LinearSolveExpr): + exprs.append(petsc_iet_mapper[e.operation](e, 
operation=e.operation)) else: exprs.append(Expression(e, operation=e.operation)) body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs) diff --git a/devito/ir/iet/efunc.py b/devito/ir/iet/efunc.py index e7192ecf9b..86eeac77e0 100644 --- a/devito/ir/iet/efunc.py +++ b/devito/ir/iet/efunc.py @@ -1,6 +1,6 @@ from functools import cached_property -from devito.ir.iet.nodes import Call, Callable +from devito.ir.iet.nodes import Call, Callable, FixedArgsCallable from devito.ir.iet.utils import derive_parameters from devito.symbolics import uxreplace from devito.tools import as_tuple @@ -131,7 +131,7 @@ class AsyncCall(Call): pass -class ThreadCallable(Callable): +class ThreadCallable(FixedArgsCallable): """ A Callable executed asynchronously by a thread. diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 9bcc3460f6..ebe5f81f1a 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1,6 +1,7 @@ """The Iteration/Expression Tree (IET) hierarchy.""" import abc +import ctypes import inspect from functools import cached_property from collections import OrderedDict, namedtuple @@ -28,7 +29,7 @@ 'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace', - 'CallableBody', 'Transfer'] + 'CallableBody', 'Transfer', 'Callback', 'FixedArgsCallable'] # First-class IET nodes @@ -750,6 +751,15 @@ def defines(self): return self.all_parameters +class FixedArgsCallable(Callable): + + """ + A Callable class that enforces a fixed function signature. + """ + + pass + + class CallableBody(MultiTraversable): """ @@ -1028,8 +1038,8 @@ class Dereference(ExprStmt, Node): The following cases are supported: * `pointer` is a PointerArray or TempFunction, and `pointee` is an Array. - * `pointer` is an ArrayObject representing a pointer to a C struct, and - `pointee` is a field in `pointer`. 
+ * `pointer` is an ArrayObject or CCompositeObject representing a pointer + to a C struct, and `pointee` is a field in `pointer`. """ is_Dereference = True @@ -1048,13 +1058,14 @@ def functions(self): @property def expr_symbols(self): - ret = [self.pointer.indexed] + ret = [] if self.pointer.is_PointerArray or self.pointer.is_TempFunction: - ret.append(self.pointee.indexed) + ret.extend([self.pointer.indexed, self.pointee.indexed]) ret.extend(flatten(i.free_symbols for i in self.pointee.symbolic_shape[1:])) ret.extend(self.pointer.free_symbols) else: - ret.append(self.pointee._C_symbol) + assert issubclass(self.pointer._C_ctype, ctypes._Pointer) + ret.extend([self.pointer._C_symbol, self.pointee._C_symbol]) return tuple(filter_ordered(ret)) @property @@ -1120,6 +1131,45 @@ def defines(self): return tuple(self.parameters) +class Callback(Call): + """ + Base class for special callback types. + + Parameters + ---------- + name : str + The name of the callback. + retval : str + The return type of the callback. + param_types : str or list of str + The type of each argument of the callback. + + Notes + ----- + - The reason Callback is an IET type rather than a SymPy type is + due to the fact that, when represented at the SymPy level, the IET + engine fails to bind the callback to a specific Call. Consequently, + errors occur during the creation of the call graph. + """ + # TODO: Create a common base class for Call and Callback to avoid + # having arguments=None here + def __init__(self, name, retval=None, param_types=None, arguments=None): + super().__init__(name=name) + self.retval = retval + self.param_types = as_tuple(param_types) + + @property + def callback_form(self): + """ + A string representation of the callback form. + + Notes + ----- + To be overridden by subclasses. 
+ """ + return + + class Section(List): """ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 505fe2e001..9359b2c3c0 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -253,7 +253,7 @@ def _gen_value(self, obj, mode=1, masked=()): strtype = '%s%s' % (strtype, self._restrict_keyword) strtype = ' '.join(qualifiers + [strtype]) - if obj.is_LocalObject and obj._C_modifier is not None and mode == 2: + if obj.is_LocalType and obj._C_modifier is not None and mode == 2: strtype += obj._C_modifier strname = obj._C_name @@ -613,6 +613,9 @@ def visit_Lambda(self, o): (', '.join(captures), ', '.join(decls), ''.join(extra))) return LambdaCollection([top, c.Block(body)]) + def visit_Callback(self, o, nested_call=False): + return CallbackArg(o) + def visit_HaloSpot(self, o): body = flatten(self._visit(i) for i in o.children) return c.Collection(body) @@ -1416,3 +1419,12 @@ def sorted_efuncs(efuncs): CommCallable: 1 } return sorted_priority(efuncs, priority) + + +class CallbackArg(c.Generable): + + def __init__(self, callback): + self.callback = callback + + def generate(self): + yield self.callback.callback_form diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 24258ad671..3172b90ae4 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -31,6 +31,8 @@ split, timed_pass, timed_region, contains_val) from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) +from devito.petsc.iet.passes import lower_petsc +from devito.petsc.clusters import petsc_preprocess __all__ = ['Operator'] @@ -374,6 +376,9 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs): # Build a sequence of Clusters from a sequence of Eqs clusters = clusterize(expressions, **kwargs) + # Preprocess clusters for PETSc lowering + clusters = petsc_preprocess(clusters) + # Operation count before specialization init_ops = sum(estimate_cost(c.exprs) for c in clusters if c.is_dense) @@ 
-471,6 +476,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) + + lower_petsc(graph, **kwargs) + graph = cls._specialize_iet(graph, **kwargs) # Instrument the IET for C-level profiling @@ -502,7 +510,7 @@ def dimensions(self): # During compilation other Dimensions may have been produced dimensions = FindSymbols('dimensions').visit(self) - ret.update(d for d in dimensions if d.is_PerfKnob) + ret.update(dimensions) ret = tuple(sorted(ret, key=attrgetter('name'))) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 28981d3bf1..557990032c 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -17,7 +17,8 @@ from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, SizeOf, VOID, Keyword, pow_to_mul) from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten -from devito.types import Array, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol +from devito.types import (Array, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol, + IndexedData) __all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage'] @@ -295,6 +296,10 @@ def _inject_definitions(self, iet, storage): init = self.lang['thread-num'](retobj=tid) frees.append(Block(header=header, body=[init] + body)) frees.extend(as_list(cbody.frees) + flatten(v.frees)) + frees = sorted(frees, key=lambda x: min( + (obj._C_free_priority for obj in FindSymbols().visit(x) + if obj.is_LocalObject), default=float('inf') + )) # maps/unmaps maps = as_list(cbody.maps) + flatten(v.maps) @@ -355,10 +360,10 @@ def place_definitions(self, iet, globs=None, **kwargs): elif globs is not None: # Track, to be handled by the EntryFunction being a global obj! 
globs.add(i) - elif i.is_ObjectArray: - self._alloc_object_array_on_low_lat_mem(iet, i, storage) elif i.is_PointerArray: self._alloc_pointed_array_on_high_bw_mem(iet, i, storage) + else: + self._alloc_object_array_on_low_lat_mem(iet, i, storage) # Handle postponed global objects includes = set() @@ -394,7 +399,8 @@ def place_casts(self, iet, **kwargs): # (i) Dereferencing a PointerArray, e.g., `float (*r0)[.] = (float(*)[.]) pr0[.]` # (ii) Declaring a raw pointer, e.g., `float * r0 = NULL; *malloc(&(r0), ...) defines = set(FindSymbols('defines|globals').visit(iet)) - bases = sorted({i.base for i in indexeds}, key=lambda i: i.name) + bases = sorted({i.base for i in indexeds + if isinstance(i.base, IndexedData)}, key=lambda i: i.name) # Some objects don't distinguish their _C_symbol because they are known, # by construction, not to require it, thus making the generated code diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 7221e985d5..cb36d18541 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -3,7 +3,7 @@ from devito.ir.iet import (Call, ExprStmt, Iteration, SyncSpot, AsyncCallable, FindNodes, FindSymbols, MapNodes, MetaCall, Transformer, - EntryFunction, ThreadCallable, Uxreplace, + EntryFunction, FixedArgsCallable, Uxreplace, derive_parameters) from devito.ir.support import SymbolRegistry from devito.mpi.distributed import MPINeighborhood @@ -129,6 +129,7 @@ def apply(self, func, **kwargs): compiler.add_libraries(as_tuple(metadata.get('libs'))) compiler.add_library_dirs(as_tuple(metadata.get('lib_dirs')), rpath=metadata.get('rpath', False)) + compiler.add_ldflags(as_tuple(metadata.get('ldflags'))) except KeyError: pass @@ -601,12 +602,12 @@ def update_args(root, efuncs, dag): foo(..., z) : root(x, z) """ - if isinstance(root, ThreadCallable): + if isinstance(root, FixedArgsCallable): return efuncs # The parameters/arguments lists may have changed since a pass may have: # 1) introduced a new symbol - 
new_params = derive_parameters(root) + new_params = derive_parameters(root, drop_locals=True) # 2) defined a symbol for which no definition was available yet (e.g. # via a malloc, or a Dereference) diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 9f49440709..e590de3d4e 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -79,7 +79,9 @@ def rule1(dep, candidates, loc_dims): rules = [rule0, rule1] # Precompute scopes to save time - scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()} + scopes = {} + for i, v in MapNodes(child_types=Expression).visit(iet).items(): + scopes[i] = Scope([e.expr for e in v]) # Analysis hsmapper = {} diff --git a/devito/petsc/__init__.py b/devito/petsc/__init__.py new file mode 100644 index 0000000000..2927bc960e --- /dev/null +++ b/devito/petsc/__init__.py @@ -0,0 +1 @@ +from devito.petsc.solve import * # noqa diff --git a/devito/petsc/clusters.py b/devito/petsc/clusters.py new file mode 100644 index 0000000000..0a6679b2dd --- /dev/null +++ b/devito/petsc/clusters.py @@ -0,0 +1,27 @@ +from devito.tools import timed_pass +from devito.petsc.types import LinearSolveExpr + + +@timed_pass() +def petsc_preprocess(clusters): + """ + Preprocess the clusters to make them suitable for PETSc + code generation. + """ + clusters = petsc_lift(clusters) + return clusters + + +def petsc_lift(clusters): + """ + Lift the iteration space surrounding each PETSc solve to create + distinct iteration loops. 
+ """ + processed = [] + for c in clusters: + if isinstance(c.exprs[0].rhs, LinearSolveExpr): + ispace = c.ispace.lift(c.exprs[0].rhs.target.space_dimensions) + processed.append(c.rebuild(ispace=ispace)) + else: + processed.append(c) + return processed diff --git a/devito/petsc/iet/__init__.py b/devito/petsc/iet/__init__.py new file mode 100644 index 0000000000..beb6d1f2d1 --- /dev/null +++ b/devito/petsc/iet/__init__.py @@ -0,0 +1 @@ +from devito.petsc.iet import * # noqa diff --git a/devito/petsc/iet/nodes.py b/devito/petsc/iet/nodes.py new file mode 100644 index 0000000000..2137495487 --- /dev/null +++ b/devito/petsc/iet/nodes.py @@ -0,0 +1,39 @@ +from devito.ir.iet import Expression, Callback, FixedArgsCallable, Call +from devito.ir.equations import OpInjectSolve + + +class LinearSolverExpression(Expression): + """ + Base class for general expressions required by a + matrix-free linear solve of the form Ax=b. + """ + pass + + +class InjectSolveDummy(LinearSolverExpression): + """ + Placeholder expression to run the iterative solver. 
+ """ + def __init__(self, expr, pragmas=None, operation=OpInjectSolve): + super().__init__(expr, pragmas=pragmas, operation=operation) + + +class PETScCallable(FixedArgsCallable): + pass + + +class MatVecCallback(Callback): + @property + def callback_form(self): + param_types_str = ', '.join([str(t) for t in self.param_types]) + return "(%s (*)(%s))%s" % (self.retval, param_types_str, self.name) + + +class FormFunctionCallback(Callback): + @property + def callback_form(self): + return "%s" % self.name + + +class PETScCall(Call): + pass diff --git a/devito/petsc/iet/passes.py b/devito/petsc/iet/passes.py new file mode 100644 index 0000000000..835ad102fb --- /dev/null +++ b/devito/petsc/iet/passes.py @@ -0,0 +1,310 @@ +import cgen as c + +from devito.passes.iet.engine import iet_pass +from devito.ir.iet import (Transformer, MapNodes, Iteration, List, BlankLine, + DummyExpr, FindNodes, retrieve_iteration_tree, + filter_iterations) +from devito.symbolics import Byref, Macro, FieldFromComposite +from devito.petsc.types import (PetscMPIInt, DM, Mat, LocalVec, GlobalVec, + KSP, PC, SNES, PetscErrorCode, DummyArg, PetscInt, + StartPtr) +from devito.petsc.iet.nodes import InjectSolveDummy, PETScCall +from devito.petsc.utils import solver_mapper, core_metadata +from devito.petsc.iet.routines import PETScCallbackBuilder +from devito.petsc.iet.utils import petsc_call, petsc_call_mpi + + +@iet_pass +def lower_petsc(iet, **kwargs): + # Check if PETScSolve was used + injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy, + 'groupby').visit(iet) + + if not injectsolve_mapper: + return iet, {} + + targets = [i.expr.rhs.target for (i,) in injectsolve_mapper.values()] + init = init_petsc(**kwargs) + + # Assumption is that all targets have the same grid so can use any target here + objs = build_core_objects(targets[-1], **kwargs) + + # Create core PETSc calls (not specific to each PETScSolve) + core = make_core_petsc_calls(objs, **kwargs) + + setup = [] + subs = {} + + # Create a 
different DMDA for each target with a unique space order + unique_dmdas = create_dmda_objs(targets) + objs.update(unique_dmdas) + for dmda in unique_dmdas.values(): + setup.extend(create_dmda_calls(dmda, objs)) + + builder = PETScCallbackBuilder(**kwargs) + + for iters, (injectsolve,) in injectsolve_mapper.items(): + solver_objs = build_solver_objs(injectsolve, iters, **kwargs) + + # Generate the solver setup for each InjectSolveDummy + solver_setup = generate_solver_setup(solver_objs, objs, injectsolve) + setup.extend(solver_setup) + + # Generate all PETSc callback functions for the target via recursive compilation + matvec_op, formfunc_op, runsolve = builder.make(injectsolve, + objs, solver_objs) + setup.extend([matvec_op, formfunc_op, BlankLine]) + # Only Transform the spatial iteration loop + space_iter, = spatial_injectsolve_iter(iters, injectsolve) + subs.update({space_iter: List(body=runsolve)}) + + # Generate callback to populate main struct object + struct, struct_calls = builder.make_main_struct(unique_dmdas, objs) + setup.extend(struct_calls) + + iet = Transformer(subs).visit(iet) + + iet = assign_time_iters(iet, struct) + + body = core + tuple(setup) + (BlankLine,) + iet.body.body + body = iet.body._rebuild( + init=init, body=body, + frees=(c.Line("PetscCall(PetscFinalize());"),) + ) + iet = iet._rebuild(body=body) + metadata = core_metadata() + efuncs = tuple(builder.efuncs.values()) + metadata.update({'efuncs': efuncs}) + + return iet, metadata + + +def init_petsc(**kwargs): + # Initialize PETSc -> for now, assuming all solver options have to be + # specified via the parameters dict in PETScSolve + # TODO: Are users going to be able to use PETSc command line arguments? 
+ # In firedrake, they have an options_prefix for each solver, enabling the use + # of command line options + initialize = petsc_call('PetscInitialize', [Null, Null, Null, Null]) + + return petsc_func_begin_user, initialize + + +def make_core_petsc_calls(objs, **kwargs): + call_mpi = petsc_call_mpi('MPI_Comm_size', [objs['comm'], Byref(objs['size'])]) + + return call_mpi, BlankLine + + +def build_core_objects(target, **kwargs): + if kwargs['options']['mpi']: + communicator = target.grid.distributor._obj_comm + else: + communicator = 'PETSC_COMM_SELF' + + return { + 'size': PetscMPIInt(name='size'), + 'comm': communicator, + 'err': PetscErrorCode(name='err'), + 'grid': target.grid + } + + +def create_dmda_objs(unique_targets): + unique_dmdas = {} + for target in unique_targets: + name = 'da_so_%s' % target.space_order + unique_dmdas[name] = DM(name=name, liveness='eager', + stencil_width=target.space_order) + return unique_dmdas + + +def create_dmda_calls(dmda, objs): + dmda_create = create_dmda(dmda, objs) + dm_setup = petsc_call('DMSetUp', [dmda]) + dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL']) + dm_get_local_info = petsc_call('DMDAGetLocalInfo', [dmda, Byref(dmda.info)]) + return dmda_create, dm_setup, dm_mat_type, dm_get_local_info, BlankLine + + +def create_dmda(dmda, objs): + no_of_space_dims = len(objs['grid'].dimensions) + + # MPI communicator + args = [objs['comm']] + + # Type of ghost nodes + args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(no_of_space_dims)]) + + # Stencil type + if no_of_space_dims > 1: + args.append('DMDA_STENCIL_BOX') + + # Global dimensions + args.extend(list(objs['grid'].shape)[::-1]) + # No.of processors in each dimension + if no_of_space_dims > 1: + args.extend(list(objs['grid'].distributor.topology)[::-1]) + + # Number of degrees of freedom per node + args.append(1) + # "Stencil width" -> size of overlap + args.append(dmda.stencil_width) + args.extend([Null for _ in range(no_of_space_dims)]) + + # The distributed 
array object + args.append(Byref(dmda)) + + # The PETSc call used to create the DMDA + dmda = petsc_call('DMDACreate%sd' % no_of_space_dims, args) + + return dmda + + +def build_solver_objs(injectsolve, iters, **kwargs): + target = injectsolve.expr.rhs.target + sreg = kwargs['sregistry'] + return { + 'Jac': Mat(sreg.make_name(prefix='J_')), + 'x_global': GlobalVec(sreg.make_name(prefix='x_global_')), + 'x_local': LocalVec(sreg.make_name(prefix='x_local_'), liveness='eager'), + 'b_global': GlobalVec(sreg.make_name(prefix='b_global_')), + 'b_local': LocalVec(sreg.make_name(prefix='b_local_')), + 'ksp': KSP(sreg.make_name(prefix='ksp_')), + 'pc': PC(sreg.make_name(prefix='pc_')), + 'snes': SNES(sreg.make_name(prefix='snes_')), + 'X_global': GlobalVec(sreg.make_name(prefix='X_global_')), + 'Y_global': GlobalVec(sreg.make_name(prefix='Y_global_')), + 'X_local': LocalVec(sreg.make_name(prefix='X_local_'), liveness='eager'), + 'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'), + 'dummy': DummyArg(sreg.make_name(prefix='dummy_')), + 'localsize': PetscInt(sreg.make_name(prefix='localsize_')), + 'start_ptr': StartPtr(sreg.make_name(prefix='start_ptr_'), target.dtype), + 'true_dims': retrieve_time_dims(iters), + 'target': target, + 'time_mapper': injectsolve.expr.rhs.time_mapper, + } + + +def generate_solver_setup(solver_objs, objs, injectsolve): + target = solver_objs['target'] + + dmda = objs['da_so_%s' % target.space_order] + + solver_params = injectsolve.expr.rhs.solver_parameters + + snes_create = petsc_call('SNESCreate', [objs['comm'], Byref(solver_objs['snes'])]) + + snes_set_dm = petsc_call('SNESSetDM', [solver_objs['snes'], dmda]) + + create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(solver_objs['Jac'])]) + + # NOTE: Assuming all solves are linear for now. 
+ snes_set_type = petsc_call('SNESSetType', [solver_objs['snes'], 'SNESKSPONLY']) + + snes_set_jac = petsc_call( + 'SNESSetJacobian', [solver_objs['snes'], solver_objs['Jac'], + solver_objs['Jac'], 'MatMFFDComputeJacobian', Null] + ) + + global_x = petsc_call('DMCreateGlobalVector', + [dmda, Byref(solver_objs['x_global'])]) + + global_b = petsc_call('DMCreateGlobalVector', + [dmda, Byref(solver_objs['b_global'])]) + + local_b = petsc_call('DMCreateLocalVector', + [dmda, Byref(solver_objs['b_local'])]) + + snes_get_ksp = petsc_call('SNESGetKSP', + [solver_objs['snes'], Byref(solver_objs['ksp'])]) + + ksp_set_tols = petsc_call( + 'KSPSetTolerances', [solver_objs['ksp'], solver_params['ksp_rtol'], + solver_params['ksp_atol'], solver_params['ksp_divtol'], + solver_params['ksp_max_it']] + ) + + ksp_set_type = petsc_call( + 'KSPSetType', [solver_objs['ksp'], solver_mapper[solver_params['ksp_type']]] + ) + + ksp_get_pc = petsc_call('KSPGetPC', [solver_objs['ksp'], Byref(solver_objs['pc'])]) + + # Even though the default will be jacobi, set to PCNONE for now + pc_set_type = petsc_call('PCSetType', [solver_objs['pc'], 'PCNONE']) + + ksp_set_from_ops = petsc_call('KSPSetFromOptions', [solver_objs['ksp']]) + + return ( + snes_create, + snes_set_dm, + create_matrix, + snes_set_jac, + snes_set_type, + global_x, + global_b, + local_b, + snes_get_ksp, + ksp_set_tols, + ksp_set_type, + ksp_get_pc, + pc_set_type, + ksp_set_from_ops + ) + + +def assign_time_iters(iet, struct): + """ + Assign time iterators to the struct within loops containing PETScCalls. + Ensure that assignment occurs only once per time loop, if necessary. + Assign only the iterators that are common between the struct fields + and the actual Iteration. 
+ """ + time_iters = [ + i for i in FindNodes(Iteration).visit(iet) + if i.dim.is_Time and FindNodes(PETScCall).visit(i) + ] + + if not time_iters: + return iet + + mapper = {} + for iter in time_iters: + common_dims = [dim for dim in iter.dimensions if dim in struct.fields] + common_dims = [ + DummyExpr(FieldFromComposite(dim, struct), dim) for dim in common_dims + ] + iter_new = iter._rebuild(nodes=List(body=tuple(common_dims)+iter.nodes)) + mapper.update({iter: iter_new}) + + return Transformer(mapper).visit(iet) + + +def retrieve_time_dims(iters): + time_iter = [i for i in iters if any(dim.is_Time for dim in i.dimensions)] + mapper = {} + if not time_iter: + return mapper + for dim in time_iter[0].dimensions: + if dim.is_Modulo: + mapper[dim.origin] = dim + elif dim.is_Time: + mapper[dim] = dim + return mapper + + +def spatial_injectsolve_iter(iter, injectsolve): + spatial_body = [] + for tree in retrieve_iteration_tree(iter[0]): + root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0] + if injectsolve in FindNodes(InjectSolveDummy).visit(root): + spatial_body.append(root) + return spatial_body + + +Null = Macro('NULL') +void = 'void' + +# TODO: Don't use c.Line here? 
+petsc_func_begin_user = c.Line('PetscFunctionBeginUser;') diff --git a/devito/petsc/iet/routines.py b/devito/petsc/iet/routines.py new file mode 100644 index 0000000000..a516b1bc81 --- /dev/null +++ b/devito/petsc/iet/routines.py @@ -0,0 +1,525 @@ +from collections import OrderedDict + +import cgen as c + +from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody, + Dereference, DummyExpr, BlankLine, Callable) +from devito.symbolics import Byref, FieldFromPointer, Macro, cast_mapper +from devito.symbolics.unevaluation import Mul +from devito.types.basic import AbstractFunction +from devito.types import ModuloDimension, TimeDimension, Temp +from devito.tools import filter_ordered +from devito.petsc.types import PETScArray +from devito.petsc.iet.nodes import (PETScCallable, FormFunctionCallback, + MatVecCallback) +from devito.petsc.iet.utils import petsc_call, petsc_struct +from devito.ir.support import SymbolRegistry + + +class PETScCallbackBuilder: + """ + Build IET routines to generate PETSc callback functions. 
+ """ + def __new__(cls, rcompile=None, sregistry=None, **kwargs): + obj = object.__new__(cls) + obj.rcompile = rcompile + obj.sregistry = sregistry + obj._efuncs = OrderedDict() + obj._struct_params = [] + + return obj + + @property + def efuncs(self): + return self._efuncs + + @property + def struct_params(self): + return self._struct_params + + def make(self, injectsolve, objs, solver_objs): + matvec_callback, formfunc_callback, formrhs_callback = self.make_all( + injectsolve, objs, solver_objs + ) + + matvec_operation = petsc_call( + 'MatShellSetOperation', [solver_objs['Jac'], 'MATOP_MULT', + MatVecCallback(matvec_callback.name, void, void)] + ) + formfunc_operation = petsc_call( + 'SNESSetFunction', + [solver_objs['snes'], Null, + FormFunctionCallback(formfunc_callback.name, void, void), Null] + ) + runsolve = self.runsolve(solver_objs, objs, formrhs_callback, injectsolve) + + return matvec_operation, formfunc_operation, runsolve + + def make_all(self, injectsolve, objs, solver_objs): + matvec_callback = self.make_matvec(injectsolve, objs, solver_objs) + formfunc_callback = self.make_formfunc(injectsolve, objs, solver_objs) + formrhs_callback = self.make_formrhs(injectsolve, objs, solver_objs) + + self._efuncs[matvec_callback.name] = matvec_callback + self._efuncs[formfunc_callback.name] = formfunc_callback + self._efuncs[formrhs_callback.name] = formrhs_callback + + return matvec_callback, formfunc_callback, formrhs_callback + + def make_matvec(self, injectsolve, objs, solver_objs): + # Compile matvec `eqns` into an IET via recursive compilation + irs_matvec, _ = self.rcompile(injectsolve.expr.rhs.matvecs, + options={'mpi': False}, sregistry=SymbolRegistry()) + body_matvec = self.create_matvec_body(injectsolve, + List(body=irs_matvec.uiet.body), + solver_objs, objs) + + matvec_callback = PETScCallable( + self.sregistry.make_name(prefix='MyMatShellMult_'), body_matvec, + retval=objs['err'], + parameters=( + solver_objs['Jac'], solver_objs['X_global'], 
solver_objs['Y_global'] + ) + ) + return matvec_callback + + def create_matvec_body(self, injectsolve, body, solver_objs, objs): + linsolveexpr = injectsolve.expr.rhs + + dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + + body = uxreplace_time(body, solver_objs) + + struct = build_local_struct(body, 'matvec', liveness='eager') + + y_matvec = linsolveexpr.arrays['y_matvec'] + x_matvec = linsolveexpr.arrays['x_matvec'] + + mat_get_dm = petsc_call('MatGetDM', [solver_objs['Jac'], Byref(dmda)]) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(struct._C_symbol)] + ) + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(solver_objs['X_local'])] + ) + + global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, solver_objs['X_global'], + 'INSERT_VALUES', solver_objs['X_local']] + ) + + global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, solver_objs['X_global'], 'INSERT_VALUES', solver_objs['X_local'] + ]) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(solver_objs['Y_local'])] + ) + + vec_get_array_y = petsc_call( + 'VecGetArray', [solver_objs['Y_local'], Byref(y_matvec._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [solver_objs['X_local'], Byref(x_matvec._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] + ) + + vec_restore_array_y = petsc_call( + 'VecRestoreArray', [solver_objs['Y_local'], Byref(y_matvec._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [solver_objs['X_local'], Byref(x_matvec._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] + ]) + + # TODO: Some of the calls are placed in the `stacks` argument of the + # 
`CallableBody` to ensure that they precede the `cast` statements. The + # 'casts' depend on the calls, so this order is necessary. By doing this, + # you avoid having to manually construct the `casts` and can allow + # Devito to handle their construction. This is a temporary solution and + # should be revisited + + body = body._rebuild( + body=body.body + + (vec_restore_array_y, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end) + ) + + stacks = ( + mat_get_dm, + dm_get_app_context, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + vec_get_array_y, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + dereference_funcs = [Dereference(i, struct) for i in + struct.fields if isinstance(i.function, AbstractFunction)] + + matvec_body = CallableBody( + List(body=body), + init=(petsc_func_begin_user,), + stacks=stacks+tuple(dereference_funcs), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, struct) for i in struct.fields} + matvec_body = Uxreplace(subs).visit(matvec_body) + + self._struct_params.extend(struct.fields) + + return matvec_body + + def make_formfunc(self, injectsolve, objs, solver_objs): + # Compile formfunc `eqns` into an IET via recursive compilation + irs_formfunc, _ = self.rcompile( + injectsolve.expr.rhs.formfuncs, + options={'mpi': False}, sregistry=SymbolRegistry() + ) + body_formfunc = self.create_formfunc_body(injectsolve, + List(body=irs_formfunc.uiet.body), + solver_objs, objs) + + formfunc_callback = PETScCallable( + self.sregistry.make_name(prefix='FormFunction_'), body_formfunc, + retval=objs['err'], + parameters=(solver_objs['snes'], solver_objs['X_global'], + solver_objs['Y_global'], solver_objs['dummy']) + ) + return formfunc_callback + + def create_formfunc_body(self, injectsolve, body, solver_objs, objs): + 
linsolveexpr = injectsolve.expr.rhs + + dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + + body = uxreplace_time(body, solver_objs) + + struct = build_local_struct(body, 'formfunc', liveness='eager') + + y_formfunc = linsolveexpr.arrays['y_formfunc'] + x_formfunc = linsolveexpr.arrays['x_formfunc'] + + snes_get_dm = petsc_call('SNESGetDM', [solver_objs['snes'], Byref(dmda)]) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(struct._C_symbol)] + ) + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(solver_objs['X_local'])] + ) + + global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, solver_objs['X_global'], + 'INSERT_VALUES', solver_objs['X_local']] + ) + + global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, solver_objs['X_global'], 'INSERT_VALUES', solver_objs['X_local'] + ]) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(solver_objs['Y_local'])] + ) + + vec_get_array_y = petsc_call( + 'VecGetArray', [solver_objs['Y_local'], Byref(y_formfunc._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [solver_objs['X_local'], Byref(x_formfunc._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] + ) + + vec_restore_array_y = petsc_call( + 'VecRestoreArray', [solver_objs['Y_local'], Byref(y_formfunc._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [solver_objs['X_local'], Byref(x_formfunc._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, solver_objs['Y_local'], 'INSERT_VALUES', solver_objs['Y_global'] + ]) + + body = body._rebuild( + body=body.body + + (vec_restore_array_y, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end) + ) + + stacks = ( + snes_get_dm, + 
dm_get_app_context, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + vec_get_array_y, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + dereference_funcs = [Dereference(i, struct) for i in + struct.fields if isinstance(i.function, AbstractFunction)] + + formfunc_body = CallableBody( + List(body=body), + init=(petsc_func_begin_user,), + stacks=stacks+tuple(dereference_funcs), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),)) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, struct) for i in struct.fields} + formfunc_body = Uxreplace(subs).visit(formfunc_body) + + self._struct_params.extend(struct.fields) + + return formfunc_body + + def make_formrhs(self, injectsolve, objs, solver_objs): + # Compile formrhs `eqns` into an IET via recursive compilation + irs_formrhs, _ = self.rcompile(injectsolve.expr.rhs.formrhs, + options={'mpi': False}, sregistry=SymbolRegistry()) + body_formrhs = self.create_formrhs_body(injectsolve, + List(body=irs_formrhs.uiet.body), + solver_objs, objs) + + formrhs_callback = PETScCallable( + self.sregistry.make_name(prefix='FormRHS_'), body_formrhs, retval=objs['err'], + parameters=( + solver_objs['snes'], solver_objs['b_local'] + ) + ) + + return formrhs_callback + + def create_formrhs_body(self, injectsolve, body, solver_objs, objs): + linsolveexpr = injectsolve.expr.rhs + + dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + + snes_get_dm = petsc_call('SNESGetDM', [solver_objs['snes'], Byref(dmda)]) + + b_arr = linsolveexpr.arrays['b_tmp'] + + vec_get_array = petsc_call( + 'VecGetArray', [solver_objs['b_local'], Byref(b_arr._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] + ) + + body = uxreplace_time(body, solver_objs) + + struct = build_local_struct(body, 'formrhs', liveness='eager') + + dm_get_app_context = petsc_call( + 
'DMGetApplicationContext', [dmda, Byref(struct._C_symbol)] + ) + + vec_restore_array = petsc_call( + 'VecRestoreArray', [solver_objs['b_local'], Byref(b_arr._C_symbol)] + ) + + body = body._rebuild(body=body.body + (vec_restore_array,)) + + stacks = ( + snes_get_dm, + dm_get_app_context, + vec_get_array, + dm_get_local_info + ) + + # Dereference function data in struct + dereference_funcs = [Dereference(i, struct) for i in + struct.fields if isinstance(i.function, AbstractFunction)] + + formrhs_body = CallableBody( + List(body=[body]), + init=(petsc_func_begin_user,), + stacks=stacks+tuple(dereference_funcs), + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, struct) for + i in struct.fields if not isinstance(i.function, AbstractFunction)} + + formrhs_body = Uxreplace(subs).visit(formrhs_body) + + self._struct_params.extend(struct.fields) + + return formrhs_body + + def runsolve(self, solver_objs, objs, rhs_callback, injectsolve): + target = injectsolve.expr.rhs.target + + dmda = objs['da_so_%s' % target.space_order] + + rhs_call = petsc_call(rhs_callback.name, list(rhs_callback.parameters)) + + local_x = petsc_call('DMCreateLocalVector', + [dmda, Byref(solver_objs['x_local'])]) + + if any(i.is_Time for i in target.dimensions): + vec_replace_array = time_dep_replace( + injectsolve, solver_objs, objs, self.sregistry + ) + else: + field_from_ptr = FieldFromPointer(target._C_field_data, target._C_symbol) + vec_replace_array = (petsc_call( + 'VecReplaceArray', [solver_objs['x_local'], field_from_ptr] + ),) + + dm_local_to_global_x = petsc_call( + 'DMLocalToGlobal', [dmda, solver_objs['x_local'], 'INSERT_VALUES', + solver_objs['x_global']] + ) + + dm_local_to_global_b = petsc_call( + 'DMLocalToGlobal', [dmda, solver_objs['b_local'], 'INSERT_VALUES', + solver_objs['b_global']] + ) + + snes_solve = petsc_call('SNESSolve', [ + solver_objs['snes'], 
def build_local_struct(iet, name, liveness):
    """
    Place all context data required by a shell routine into a struct.

    Every 'basics' symbol found in `iet` contributes its `.function` as a
    field, except for:

    * symbols whose function is a `PETScArray` or a `Temp`;
    * Dimensions that are neither a `TimeDimension` nor a `ModuloDimension`.

    The resulting fields are handed to `petsc_struct` along with `name` and
    `liveness`.
    """
    fields = []
    for symbol in FindSymbols('basics').visit(iet):
        if isinstance(symbol.function, (PETScArray, Temp)):
            continue

        # Among Dimensions, only the time-related ones are retained
        time_like = isinstance(symbol, (TimeDimension, ModuloDimension))
        if symbol.is_Dimension and not time_like:
            continue

        fields.append(symbol.function)

    return petsc_struct(name, fields, liveness)
def uxreplace_time(body, solver_objs):
    """
    Substitute the placeholder time symbols appearing in `body`.

    For each ``(time index, placeholder)`` pair in
    ``solver_objs['time_mapper']``, the time index is normalised by replacing
    the stepping-dimension spacing with 1 (and its negation with -1); the
    normalised index is then looked up in ``solver_objs['true_dims']`` and the
    value found replaces the placeholder in `body`.
    """
    # TODO: Potentially introduce a TimeIteration abstraction to simplify
    # all the time processing that is done (searches, replacements, ...)
    # "manually" via free functions
    spacing = solver_objs['target'].grid.stepping_dim.spacing
    unit_steps = {spacing: 1, -spacing: -1}
    true_dims = solver_objs['true_dims']

    subs = {}
    for time_index, placeholder in solver_objs['time_mapper'].items():
        subs[placeholder] = true_dims[time_index.xreplace(unit_steps)]

    return Uxreplace(subs).visit(body)
def _wrap_petsc_call(wrapper, specific_call, call_args):
    # Build `wrapper(specific_call(*call_args))` as nested PETScCall nodes
    inner = PETScCall(specific_call, arguments=call_args)
    return PETScCall(wrapper, [inner])


def petsc_call(specific_call, call_args):
    """
    Return a `PETScCall` to ``PetscCall`` whose only argument is the call to
    `specific_call` with arguments `call_args`.
    """
    return _wrap_petsc_call('PetscCall', specific_call, call_args)


def petsc_call_mpi(specific_call, call_args):
    """
    Return a `PETScCall` to ``PetscCallMPI`` whose only argument is the call
    to `specific_call` with arguments `call_args`.
    """
    return _wrap_petsc_call('PetscCallMPI', specific_call, call_args)


def petsc_struct(name, fields, liveness='lazy'):
    """
    Create a `PETScStruct` called `name`, with C type name ``MatContext``,
    holding `fields` and having the given `liveness`.
    """
    # TODO: Fix this circular import
    from devito.petsc.types.object import PETScStruct

    return PETScStruct(
        name=name, pname='MatContext', fields=fields, liveness=liveness
    )
def PETScSolve(eqns, target, solver_parameters=None, **kwargs):
    """
    Build the placeholder equation used to inject a PETSc solve for `target`.

    Each equation is split into ``F(target) = b``. From it, three equations
    are generated, writing into compiler-generated `PETScArray`s: the
    matrix-vector product (`matvecs`), the function evaluation (`formfuncs`)
    and the right-hand side (`formrhs`). Everything is attached to a
    `LinearSolveExpr`, which becomes the right-hand side of the returned
    `InjectSolveEq`.

    Parameters
    ----------
    eqns : Eq or iterable of Eq
        The equation(s) to solve. At least one is required.
    target
        The unknown being solved for. Its name, dtype, grid, halo, space
        dimensions and space order shape the generated `PETScArray`s.
    solver_parameters : dict, optional
        Forwarded to `LinearSolveExpr`.
    **kwargs
        Accepted but currently unused.

    Returns
    -------
    list
        A one-element list containing the `InjectSolveEq`.

    Raises
    ------
    ValueError
        If `eqns` is empty.
    """
    eqns = as_tuple(eqns)
    if not eqns:
        # The injected equation borrows its subdomain from `eqns`, so there
        # must be at least one of them
        raise ValueError("PETScSolve requires at least one equation")

    prefixes = ['y_matvec', 'x_matvec', 'y_formfunc', 'x_formfunc', 'b_tmp']

    arrays = {
        p: PETScArray(name='%s_%s' % (p, target.name),
                      dtype=target.dtype,
                      dimensions=target.space_dimensions,
                      shape=target.grid.shape,
                      liveness='eager',
                      halo=[target.halo[d] for d in target.space_dimensions],
                      space_order=target.space_order)
        for p in prefixes
    }

    matvecs = []
    formfuncs = []
    formrhs = []

    funcs = retrieve_functions(eqns)
    time_mapper = generate_time_mapper(funcs)

    for eq in eqns:
        b, F_target = separate_eqn(eq, target)
        b, F_target = b.subs(time_mapper), F_target.subs(time_mapper)

        # TODO: Current assumption is that problem is linear and user has not
        # provided a jacobian. Hence, we can use F_target to form the jac-vec
        # product
        matvecs.append(Eq(
            arrays['y_matvec'],
            F_target.subs({target: arrays['x_matvec']}),
            subdomain=eq.subdomain
        ))

        formfuncs.append(Eq(
            arrays['y_formfunc'],
            F_target.subs({target: arrays['x_formfunc']}),
            subdomain=eq.subdomain
        ))

        formrhs.append(Eq(
            arrays['b_tmp'],
            b,
            subdomain=eq.subdomain
        ))

    # The placeholder takes the subdomain of the last equation.
    # NOTE(review): equations defined over different subdomains are not
    # distinguished here -- confirm this is intended
    subdomain = eqns[-1].subdomain

    # Placeholder equation for inserting calls to the solver and generating
    # correct time loop etc
    inject_solve = InjectSolveEq(target, LinearSolveExpr(
        expr=tuple(funcs),
        target=target,
        solver_parameters=solver_parameters,
        matvecs=matvecs,
        formfuncs=formfuncs,
        formrhs=formrhs,
        arrays=arrays,
        time_mapper=time_mapper,
    ), subdomain=subdomain)

    return [inject_solve]
def generate_time_mapper(funcs):
    """
    Replace time indices with `Symbols` in equations used within
    PETSc callback functions. These symbols are Uxreplaced at the IET
    level to align with the `TimeDimension` and `ModuloDimension` objects
    present in the initial lowering.
    NOTE: All functions used in PETSc callback functions are attached to
    the `LinearSolveExpr` object, which is passed through the initial lowering
    (and subsequently dropped and replaced with calls to run the solver).
    Therefore, the appropriate time loop will always be correctly generated inside
    the main kernel.

    Parameters
    ----------
    funcs : iterable
        Objects exposing `indices` and `dimensions`, zipped pairwise.

    Returns
    -------
    dict
        Maps each distinct time index to a `Symbol` named ``tau<N>``. For a
        `SteppingDimension` the key is the index expression itself; for any
        other time dimension the key is the dimension. ``N`` follows the
        order of first appearance while iterating `funcs`, so the numbering
        is reproducible whenever `funcs` is iterated in a reproducible order.
    """
    # `dict.fromkeys` removes duplicates while keeping first-seen order; a
    # `set` would make the tau numbering depend on hash ordering
    time_indices = dict.fromkeys(
        i if isinstance(d, SteppingDimension) else d
        for f in funcs
        for i, d in zip(f.indices, f.dimensions)
        if d.is_Time
    )
    return {
        time: Symbol('tau%d' % n) for n, time in enumerate(time_indices)
    }
+ _default_fd = 'taylor' + + __rkwargs__ = (AbstractFunction.__rkwargs__ + + ('dimensions', 'shape', 'liveness', 'coefficients', + 'space_order')) + + def __init_finalize__(self, *args, **kwargs): + + super().__init_finalize__(*args, **kwargs) + + # Symbolic (finite difference) coefficients + self._coefficients = kwargs.get('coefficients', self._default_fd) + if self._coefficients not in fd_weights_registry: + raise ValueError("coefficients must be one of %s" + " not %s" % (str(fd_weights_registry), self._coefficients)) + self._shape = kwargs.get('shape') + self._space_order = kwargs.get('space_order', 1) + + @classmethod + def __dtype_setup__(cls, **kwargs): + return kwargs.get('dtype', np.float32) + + @property + def coefficients(self): + """Form of the coefficients of the function.""" + return self._coefficients + + @property + def shape(self): + return self._shape + + @property + def space_order(self): + return self._space_order + + @cached_property + def _shape_with_inhalo(self): + """ + Shape of the domain+inhalo region. The inhalo region comprises the + outhalo as well as any additional "ghost" layers for MPI halo + exchanges. Data in the inhalo region are exchanged when running + Operators to maintain consistent values as in sequential runs. + + Notes + ----- + Typically, this property won't be used in user code, but it may come + in handy for testing or debugging + """ + return tuple(j + i + k for i, (j, k) in zip(self.shape, self._halo)) + + @cached_property + def shape_allocated(self): + """ + Shape of the allocated data of the Function type object from which + this PETScArray was derived. It includes the domain and inhalo regions, + as well as any additional padding surrounding the halo. + + Notes + ----- + In an MPI context, this is the *local* with_halo region shape. 
+ """ + return DimensionTuple(*[j + i + k for i, (j, k) in zip(self._shape_with_inhalo, + self._padding)], + getters=self.dimensions) + + @cached_property + def _C_ctype(self): + # NOTE: Reverting to using float/double instead of PetscScalar for + # simplicity when opt='advanced'. Otherwise, Temp objects must also + # be converted to PetscScalar. Additional tests are needed to + # ensure this approach is fine. Previously, issues arose from + # mismatches between precision of Function objects in Devito and the + # precision of the PETSc configuration. + # TODO: Use cat $PETSC_DIR/$PETSC_ARCH/lib/petsc/conf/petscvariables + # | grep -E "PETSC_(SCALAR|PRECISION)" to determine the precision of + # the user's PETSc configuration. + return POINTER(dtype_to_ctype(self.dtype)) + + @property + def symbolic_shape(self): + field_from_composites = [ + FieldFromComposite('g%sm' % d.name, self.dmda.info) for d in self.dimensions] + # Reverse it since DMDA is setup backwards to Devito dimensions. + return DimensionTuple(*field_from_composites[::-1], getters=self.dimensions) + + @cached_property + def dmda(self): + name = 'da_so_%s' % self.space_order + return DM(name=name, liveness='eager', stencil_width=self.space_order) diff --git a/devito/petsc/types/object.py b/devito/petsc/types/object.py new file mode 100644 index 0000000000..9f8cbe4cbb --- /dev/null +++ b/devito/petsc/types/object.py @@ -0,0 +1,177 @@ +from ctypes import POINTER + +from devito.tools import CustomDtype, dtype_to_cstr +from devito.types import LocalObject, CCompositeObject, ModuloDimension, TimeDimension +from devito.symbolics import Byref + +from devito.petsc.iet.utils import petsc_call + + +class DM(LocalObject): + """ + PETSc Data Management object (DM). 
+ """ + dtype = CustomDtype('DM') + + def __init__(self, *args, stencil_width=None, **kwargs): + super().__init__(*args, **kwargs) + self._stencil_width = stencil_width + + @property + def stencil_width(self): + return self._stencil_width + + @property + def info(self): + return DMDALocalInfo(name='%s_info' % self.name, liveness='eager') + + @property + def _C_free(self): + return petsc_call('DMDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 3 + + +class Mat(LocalObject): + """ + PETSc Matrix object (Mat). + """ + dtype = CustomDtype('Mat') + + @property + def _C_free(self): + return petsc_call('MatDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 1 + + +class LocalVec(LocalObject): + """ + PETSc Vector object (Vec). + """ + dtype = CustomDtype('Vec') + + +class GlobalVec(LocalObject): + """ + PETSc Vector object (Vec). + """ + dtype = CustomDtype('Vec') + + @property + def _C_free(self): + return petsc_call('VecDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 0 + + +class PetscMPIInt(LocalObject): + """ + PETSc datatype used to represent `int` parameters + to MPI functions. + """ + dtype = CustomDtype('PetscMPIInt') + + +class PetscInt(LocalObject): + """ + PETSc datatype used to represent `int` parameters + to PETSc functions. + """ + dtype = CustomDtype('PetscInt') + + +class KSP(LocalObject): + """ + PETSc KSP : Linear Systems Solvers. + Manages Krylov Methods. + """ + dtype = CustomDtype('KSP') + + +class SNES(LocalObject): + """ + PETSc SNES : Non-Linear Systems Solvers. + """ + dtype = CustomDtype('SNES') + + @property + def _C_free(self): + return petsc_call('SNESDestroy', [Byref(self.function)]) + + @property + def _C_free_priority(self): + return 2 + + +class PC(LocalObject): + """ + PETSc object that manages all preconditioners (PC). 
class LinearSolveExpr(sympy.Function, Reconstructable):
    """
    A sympy Function wrapping `expr` and carrying the metadata attached at
    construction time: the `target`, the `solver_parameters`, the callback
    equations (`matvecs`, `formfuncs`, `formrhs`), the `arrays` they write to
    and the `time_mapper`.

    Two instances compare equal when both `expr` and `target` are equal; the
    hash is derived from `expr` alone.
    """

    __rargs__ = ('expr',)
    __rkwargs__ = ('target', 'solver_parameters', 'matvecs',
                   'formfuncs', 'formrhs', 'arrays', 'time_mapper')

    # Fallback values for any solver parameter not supplied by the caller
    defaults = {
        'ksp_type': 'gmres',
        'pc_type': 'jacobi',
        'ksp_rtol': 1e-7,  # Relative tolerance
        'ksp_atol': 1e-50,  # Absolute tolerance
        'ksp_divtol': 1e4,  # Divergence tolerance
        'ksp_max_it': 10000  # Maximum iterations
    }

    def __new__(cls, expr, target=None, solver_parameters=None,
                matvecs=None, formfuncs=None, formrhs=None,
                arrays=None, time_mapper=None, **kwargs):
        """
        Create the expression.

        `solver_parameters` may be None or a mapping; keys missing from it
        are filled in from `defaults`. The mapping passed in is never
        modified and the class-level `defaults` dict is never handed out, so
        instances do not share mutable state with each other or the caller.
        """
        # Private copy, preserving the caller's keys (and their order) ...
        params = dict(solver_parameters) if solver_parameters is not None else {}
        # ... then completed with the defaults for whatever is missing
        for key, val in cls.defaults.items():
            params.setdefault(key, val)

        with sympy_mutex:
            obj = sympy.Function.__new__(cls, expr)

        obj._expr = expr
        obj._target = target
        obj._solver_parameters = params
        obj._matvecs = matvecs
        obj._formfuncs = formfuncs
        obj._formrhs = formrhs
        obj._arrays = arrays
        obj._time_mapper = time_mapper
        return obj

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.expr)

    __str__ = __repr__

    def _sympystr(self, printer):
        return str(self)

    def __hash__(self):
        # Consistent with __eq__: equal objects have equal `expr`
        return hash(self.expr)

    def __eq__(self, other):
        return (isinstance(other, LinearSolveExpr) and
                self.expr == other.expr and
                self.target == other.target)

    @property
    def expr(self):
        return self._expr

    @property
    def target(self):
        return self._target

    @property
    def solver_parameters(self):
        return self._solver_parameters

    @property
    def matvecs(self):
        return self._matvecs

    @property
    def formfuncs(self):
        return self._formfuncs

    @property
    def formrhs(self):
        return self._formrhs

    @property
    def arrays(self):
        return self._arrays

    @property
    def time_mapper(self):
        return self._time_mapper

    @classmethod
    def eval(cls, *args):
        # Returning None makes `eval` a no-op; the object is then built by
        # `__new__` as written above
        return None

    # Rebuild through the Reconstructable machinery so that the metadata in
    # `__rkwargs__` is carried over
    func = Reconstructable._rebuild
@memoized_func
def get_petsc_dir():
    """
    Return the value of the ``PETSC_DIR`` environment variable, or None if
    it is unset or empty.
    """
    # TODO: Raise error if PETSC_DIR is not set
    return os.environ.get('PETSC_DIR') or None


@memoized_func
def get_petsc_arch():
    """
    Return the value of the ``PETSC_ARCH`` environment variable, or None if
    it is unset or empty.
    """
    # TODO: Raise error if PETSC_ARCH is not set
    return os.environ.get('PETSC_ARCH') or None


def core_metadata():
    """
    Return the compilation metadata needed to build against PETSc.

    Returns
    -------
    dict
        With keys 'includes' (header names), 'include_dirs' (tuple of
        directories), 'libs' (tuple of library names), 'lib_dirs' (the
        library directory, as a string) and 'ldflags' (tuple of linker
        flags).

    Raises
    ------
    ValueError
        If ``PETSC_DIR`` is not set, since no path can be derived.
    """
    petsc_dir = get_petsc_dir()
    if petsc_dir is None:
        raise ValueError("PETSC_DIR is not set; cannot locate the PETSc "
                         "installation")
    petsc_arch = get_petsc_arch()

    # Without a PETSC_ARCH, use `petsc_dir` itself rather than producing a
    # bogus `<PETSC_DIR>/None/...` path
    arch_root = os.path.join(petsc_dir, petsc_arch) if petsc_arch else petsc_dir

    # Include directories (deduplicated, order preserved)
    global_include = os.path.join(petsc_dir, 'include')
    config_specific_include = os.path.join(arch_root, 'include')
    include_dirs = tuple(dict.fromkeys((global_include, config_specific_include)))

    # Lib directories
    lib_dir = os.path.join(arch_root, 'lib')

    return {
        'includes': ('petscksp.h', 'petscsnes.h', 'petscdmda.h'),
        'include_dirs': include_dirs,
        # NOTE: the trailing commas matter -- `('petsc')` is just a string
        'libs': ('petsc',),
        'lib_dirs': lib_dir,
        'ldflags': ('-Wl,-rpath,%s' % lib_dir,)
    }
""" _type = self._C_ctype + if isinstance(_type, CustomDtype): + return ctypes_to_cstr(_type) + while issubclass(_type, _Pointer): _type = _type._type_ diff --git a/devito/types/equation.py b/devito/types/equation.py index 662cdd0d34..8b3ead9873 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -218,3 +218,7 @@ class ReduceMax(Reduction): class ReduceMin(Reduction): pass + + +class InjectSolveEq(Eq): + pass diff --git a/devito/types/object.py b/devito/types/object.py index cba54b0add..aa738bd19b 100644 --- a/devito/types/object.py +++ b/devito/types/object.py @@ -1,5 +1,4 @@ from ctypes import byref - import sympy from devito.tools import Pickable, as_tuple, sympy_mutex @@ -8,7 +7,8 @@ from devito.types.basic import Basic, LocalType from devito.types.utils import CtypesFactory -__all__ = ['Object', 'LocalObject', 'CompositeObject'] + +__all__ = ['Object', 'LocalObject', 'CompositeObject', 'CCompositeObject'] class AbstractObject(Basic, sympy.Basic, Pickable): @@ -138,6 +138,7 @@ def __init__(self, name, pname, pfields, value=None): dtype = CtypesFactory.generate(pname, pfields) value = self.__value_setup__(dtype, value) super().__init__(name, dtype, value) + self._pname = pname def __value_setup__(self, dtype, value): return value or byref(dtype._type_()) @@ -148,7 +149,7 @@ def pfields(self): @property def pname(self): - return self.dtype._type_.__name__ + return self._pname @property def fields(self): @@ -231,6 +232,28 @@ def _C_free(self): """ return None + @property + def _C_free_priority(self): + return float('inf') + @property def _mem_global(self): return self._is_global + + +class CCompositeObject(CompositeObject, LocalType): + + """ + Object with composite type (e.g., a C struct) defined in C. 
+ """ + + __rargs__ = ('name', 'pname', 'pfields') + + def __init__(self, name, pname, pfields, liveness='lazy'): + super().__init__(name, pname, pfields) + assert liveness in ['eager', 'lazy'] + self._liveness = liveness + + @property + def dtype(self): + return self._dtype._type_ diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index bed0bbad24..cae16941ac 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -18,7 +18,8 @@ RUN apt-get update && \ # Install for basic base not containing it RUN apt-get install -y vim wget git flex libnuma-dev tmux \ numactl hwloc curl \ - autoconf libtool build-essential procps + autoconf libtool build-essential procps \ + gfortran pkgconf libopenblas-serial-dev # Install tmpi RUN curl https://raw.githubusercontent.com/Azrael3000/tmpi/master/tmpi -o /usr/local/bin/tmpi diff --git a/docker/Dockerfile.devito b/docker/Dockerfile.devito index 99b21c87fb..1e6edbfc28 100644 --- a/docker/Dockerfile.devito +++ b/docker/Dockerfile.devito @@ -4,8 +4,26 @@ # Base image with compilers ARG base=devitocodes/bases:cpu-gcc +ARG petscinstall="" -FROM $base as builder +FROM $base as copybase + +################## Install PETSc ############################################ +FROM copybase as petsccopybase + +RUN apt-get update && apt-get install -y git && \ + python3 -m venv /venv && \ + /venv/bin/pip install --no-cache-dir --upgrade pip && \ + /venv/bin/pip install --no-cache-dir --no-binary numpy numpy && \ + mkdir -p /opt/petsc && \ + cd /opt/petsc && \ + git clone -b release https://gitlab.com/petsc/petsc.git petsc && \ + cd petsc && \ + ./configure --with-fortran-bindings=0 --with-mpi-dir=/opt/openmpi --with-openblas-include=$(pkg-config --variable=includedir openblas) --with-openblas-lib=$(pkg-config --variable=libdir openblas)/libopenblas.so PETSC_ARCH=devito_build && \ + make all + +ARG petscinstall="" +FROM ${petscinstall}copybase as builder # User/Group Ids ARG USER_ID=1000 @@ -57,6 +75,9 @@ ARG GROUP_ID=1000 ENV 
HOME=/app ENV APP_HOME=/app +ENV PETSC_ARCH="devito_build" +ENV PETSC_DIR="/opt/petsc/petsc" + # Create the home directory for the new app user. # Create an app user so our program doesn't run as root. # Chown all the files to the app user. @@ -89,4 +110,3 @@ USER app EXPOSE 8888 ENTRYPOINT ["/docker-entrypoint.sh"] CMD ["/jupyter"] - diff --git a/tests/test_iet.py b/tests/test_iet.py index ff963d5518..c1c78eb7ea 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -10,7 +10,8 @@ from devito.ir.iet import (Call, Callable, Conditional, DeviceCall, DummyExpr, Iteration, List, KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations, make_efunc, - retrieve_iteration_tree, Transformer) + retrieve_iteration_tree, Transformer, Callback, + Definition, FindNodes) from devito.ir import SymbolRegistry from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager @@ -128,6 +129,33 @@ def test_find_symbols_nested(mode, expected): assert [f.name for f in found] == eval(expected) +def test_callback_cgen(): + + class FunctionPtr(Callback): + @property + def callback_form(self): + param_types = ', '.join([str(t) for t in + self.param_types]) + return "(%s (*)(%s))%s" % (self.retval, param_types, self.name) + + a = Symbol('a') + b = Symbol('b') + foo0 = Callable('foo0', Definition(a), 'void', parameters=[b]) + foo0_arg = FunctionPtr(foo0.name, foo0.retval, 'int') + code0 = CGen().visit(foo0_arg) + assert str(code0) == '(void (*)(int))foo0' + + # Test nested calls with a Callback as an argument. 
+ call = Call('foo2', [ + Call('foo1', [foo0_arg]) + ]) + code1 = CGen().visit(call) + assert str(code1) == 'foo2(foo1((void (*)(int))foo0));' + + callees = FindNodes(Call).visit(call) + assert len(callees) == 3 + + def test_list_denesting(): l0 = List(header=cgen.Line('a'), body=List(header=cgen.Line('b'))) l1 = l0._rebuild(body=List(header=cgen.Line('c'))) diff --git a/tests/test_petsc.py b/tests/test_petsc.py new file mode 100644 index 0000000000..f5616d4f36 --- /dev/null +++ b/tests/test_petsc.py @@ -0,0 +1,653 @@ +import numpy as np +import os +import pytest + +from conftest import skipif +from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig +from devito.ir.iet import (Call, ElementalFunction, Definition, DummyExpr, + FindNodes, PointerCast, retrieve_iteration_tree) +from devito.types import Constant, CCompositeObject +from devito.passes.iet.languages.C import CDataManager +from devito.petsc.types import (DM, Mat, LocalVec, PetscMPIInt, KSP, + PC, KSPConvergedReason, PETScArray, + LinearSolveExpr) +from devito.petsc.solve import PETScSolve, separate_eqn, centre_stencil +from devito.petsc.iet.nodes import Expression + + +@skipif('petsc') +def test_petsc_local_object(): + """ + Test C++ support for PETSc LocalObjects. + """ + lo0 = DM('da', stencil_width=1) + lo1 = Mat('A') + lo2 = LocalVec('x') + lo3 = PetscMPIInt('size') + lo4 = KSP('ksp') + lo5 = PC('pc') + lo6 = KSPConvergedReason('reason') + + iet = Call('foo', [lo0, lo1, lo2, lo3, lo4, lo5, lo6]) + iet = ElementalFunction('foo', iet, parameters=()) + + dm = CDataManager(sregistry=None) + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] + + assert 'DM da;' in str(iet) + assert 'Mat A;' in str(iet) + assert 'Vec x;' in str(iet) + assert 'PetscMPIInt size;' in str(iet) + assert 'KSP ksp;' in str(iet) + assert 'PC pc;' in str(iet) + assert 'KSPConvergedReason reason;' in str(iet) + + +@skipif('petsc') +def test_petsc_functions(): + """ + Test C++ support for PETScArrays. 
+ """ + grid = Grid((2, 2)) + x, y = grid.dimensions + + ptr0 = PETScArray(name='ptr0', dimensions=grid.dimensions, dtype=np.float32) + ptr1 = PETScArray(name='ptr1', dimensions=grid.dimensions, dtype=np.float32, + is_const=True) + ptr2 = PETScArray(name='ptr2', dimensions=grid.dimensions, dtype=np.float64, + is_const=True) + ptr3 = PETScArray(name='ptr3', dimensions=grid.dimensions, dtype=np.int32) + ptr4 = PETScArray(name='ptr4', dimensions=grid.dimensions, dtype=np.int64, + is_const=True) + + defn0 = Definition(ptr0) + defn1 = Definition(ptr1) + defn2 = Definition(ptr2) + defn3 = Definition(ptr3) + defn4 = Definition(ptr4) + + expr = DummyExpr(ptr0.indexed[x, y], ptr1.indexed[x, y] + 1) + + assert str(defn0) == 'float *restrict ptr0_vec;' + assert str(defn1) == 'const float *restrict ptr1_vec;' + assert str(defn2) == 'const double *restrict ptr2_vec;' + assert str(defn3) == 'int *restrict ptr3_vec;' + assert str(defn4) == 'const long *restrict ptr4_vec;' + assert str(expr) == 'ptr0[x][y] = ptr1[x][y] + 1;' + + +@skipif('petsc') +def test_petsc_subs(): + """ + Test support for PETScArrays in substitutions. + """ + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=2) + f2 = Function(name='f2', grid=grid, space_order=2) + + arr = PETScArray(name='arr', dimensions=f2.dimensions, dtype=f2.dtype) + + eqn = Eq(f1, f2.laplace) + eqn_subs = eqn.subs(f2, arr) + + assert str(eqn) == 'Eq(f1(x, y), Derivative(f2(x, y), (x, 2))' + \ + ' + Derivative(f2(x, y), (y, 2)))' + + assert str(eqn_subs) == 'Eq(f1(x, y), Derivative(arr(x, y), (x, 2))' + \ + ' + Derivative(arr(x, y), (y, 2)))' + + assert str(eqn_subs.rhs.evaluate) == '-2.0*arr(x, y)/h_x**2' + \ + ' + arr(x - h_x, y)/h_x**2 + arr(x + h_x, y)/h_x**2 - 2.0*arr(x, y)/h_y**2' + \ + ' + arr(x, y - h_y)/h_y**2 + arr(x, y + h_y)/h_y**2' + + +@skipif('petsc') +def test_petsc_solve(): + """ + Test PETScSolve. 
+ """ + grid = Grid((2, 2)) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f.laplace, g) + + petsc = PETScSolve(eqn, f) + + with switchconfig(openmp=False): + op = Operator(petsc, opt='noop') + + callable_roots = [meta_call.root for meta_call in op._func_table.values()] + + matvec_callback = [root for root in callable_roots if root.name == 'MyMatShellMult_0'] + + formrhs_callback = [root for root in callable_roots if root.name == 'FormRHS_0'] + + action_expr = FindNodes(Expression).visit(matvec_callback[0]) + rhs_expr = FindNodes(Expression).visit(formrhs_callback[0]) + + assert str(action_expr[-1].expr.rhs) == \ + 'matvec->h_x**(-2)*x_matvec_f[x + 1, y + 2]' + \ + ' - 2.0*matvec->h_x**(-2)*x_matvec_f[x + 2, y + 2]' + \ + ' + matvec->h_x**(-2)*x_matvec_f[x + 3, y + 2]' + \ + ' + matvec->h_y**(-2)*x_matvec_f[x + 2, y + 1]' + \ + ' - 2.0*matvec->h_y**(-2)*x_matvec_f[x + 2, y + 2]' + \ + ' + matvec->h_y**(-2)*x_matvec_f[x + 2, y + 3]' + + assert str(rhs_expr[-1].expr.rhs) == 'g[x + 2, y + 2]' + + # Check the iteration bounds are correct. + assert op.arguments().get('x_m') == 0 + assert op.arguments().get('y_m') == 0 + assert op.arguments().get('y_M') == 1 + assert op.arguments().get('x_M') == 1 + + assert len(retrieve_iteration_tree(op)) == 0 + + # TODO: Remove pragmas from PETSc callback functions + assert len(matvec_callback[0].parameters) == 3 + + +@skipif('petsc') +def test_multiple_petsc_solves(): + """ + Test multiple PETScSolves. 
+ """ + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + f2 = Function(name='f2', grid=grid, space_order=2) + g2 = Function(name='g2', grid=grid, space_order=2) + + eqn1 = Eq(f1.laplace, g1) + eqn2 = Eq(f2.laplace, g2) + + petsc1 = PETScSolve(eqn1, f1) + petsc2 = PETScSolve(eqn2, f2) + + with switchconfig(openmp=False): + op = Operator(petsc1+petsc2, opt='noop') + + callable_roots = [meta_call.root for meta_call in op._func_table.values()] + + # One FormRHS, one MatShellMult and one FormFunction per solve + # One PopulateMatContext for all solves + assert len(callable_roots) == 7 + + +@skipif('petsc') +def test_petsc_cast(): + """ + Test casting of PETScArray. + """ + g0 = Grid((2)) + g1 = Grid((2, 2)) + g2 = Grid((2, 2, 2)) + + arr0 = PETScArray(name='arr0', dimensions=g0.dimensions, shape=g0.shape) + arr1 = PETScArray(name='arr1', dimensions=g1.dimensions, shape=g1.shape) + arr2 = PETScArray(name='arr2', dimensions=g2.dimensions, shape=g2.shape) + + arr3 = PETScArray(name='arr3', dimensions=g1.dimensions, + shape=g1.shape, space_order=4) + + cast0 = PointerCast(arr0) + cast1 = PointerCast(arr1) + cast2 = PointerCast(arr2) + cast3 = PointerCast(arr3) + + assert str(cast0) == \ + 'float (*restrict arr0) = (float (*)) arr0_vec;' + assert str(cast1) == \ + 'float (*restrict arr1)[da_so_1_info.gxm] = ' + \ + '(float (*)[da_so_1_info.gxm]) arr1_vec;' + assert str(cast2) == \ + 'float (*restrict arr2)[da_so_1_info.gym][da_so_1_info.gxm] = ' + \ + '(float (*)[da_so_1_info.gym][da_so_1_info.gxm]) arr2_vec;' + assert str(cast3) == \ + 'float (*restrict arr3)[da_so_4_info.gxm] = ' + \ + '(float (*)[da_so_4_info.gxm]) arr3_vec;' + + +@skipif('petsc') +def test_LinearSolveExpr(): + + grid = Grid((2, 2)) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f, g.laplace) + + linsolveexpr = LinearSolveExpr(eqn.rhs, target=f) + + # 
Check the target + assert linsolveexpr.target == f + # Check the solver parameters + assert linsolveexpr.solver_parameters == \ + {'ksp_type': 'gmres', 'pc_type': 'jacobi', 'ksp_rtol': 1e-07, + 'ksp_atol': 1e-50, 'ksp_divtol': 10000.0, 'ksp_max_it': 10000} + + +@skipif('petsc') +def test_dmda_create(): + + grid1 = Grid((2)) + grid2 = Grid((2, 2)) + grid3 = Grid((4, 5, 6)) + + f1 = Function(name='f1', grid=grid1, space_order=2) + f2 = Function(name='f2', grid=grid2, space_order=4) + f3 = Function(name='f3', grid=grid3, space_order=6) + + eqn1 = Eq(f1.laplace, 10) + eqn2 = Eq(f2.laplace, 10) + eqn3 = Eq(f3.laplace, 10) + + petsc1 = PETScSolve(eqn1, f1) + petsc2 = PETScSolve(eqn2, f2) + petsc3 = PETScSolve(eqn3, f3) + + with switchconfig(openmp=False): + op1 = Operator(petsc1, opt='noop') + op2 = Operator(petsc2, opt='noop') + op3 = Operator(petsc3, opt='noop') + + assert 'PetscCall(DMDACreate1d(PETSC_COMM_SELF,DM_BOUNDARY_GHOSTED,' + \ + '2,1,2,NULL,&(da_so_2)));' in str(op1) + + assert 'PetscCall(DMDACreate2d(PETSC_COMM_SELF,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,2,2,1,1,1,4,NULL,NULL,&(da_so_4)));' \ + in str(op2) + + assert 'PetscCall(DMDACreate3d(PETSC_COMM_SELF,DM_BOUNDARY_GHOSTED,' + \ + 'DM_BOUNDARY_GHOSTED,DM_BOUNDARY_GHOSTED,DMDA_STENCIL_BOX,6,5,4' + \ + ',1,1,1,1,6,NULL,NULL,NULL,&(da_so_6)));' in str(op3) + + # Check unique DMDA is created per grid, per space_order + f4 = Function(name='f4', grid=grid2, space_order=6) + eqn4 = Eq(f4.laplace, 10) + petsc4 = PETScSolve(eqn4, f4) + with switchconfig(openmp=False): + op4 = Operator(petsc2+petsc2+petsc4, opt='noop') + assert str(op4).count('DMDACreate2d') == 2 + + +@skipif('petsc') +def test_cinterface_petsc_struct(): + + grid = Grid(shape=(11, 11)) + f = Function(name='f', grid=grid, space_order=2) + eq = Eq(f.laplace, 10) + petsc = PETScSolve(eq, f) + + name = "foo" + with switchconfig(openmp=False): + op = Operator(petsc, name=name) + + # Trigger the generation of a .c and a .h 
files + ccode, hcode = op.cinterface(force=True) + + dirname = op._compiler.get_jit_dir() + assert os.path.isfile(os.path.join(dirname, "%s.c" % name)) + assert os.path.isfile(os.path.join(dirname, "%s.h" % name)) + + ccode = str(ccode) + hcode = str(hcode) + + assert 'include "%s.h"' % name in ccode + + # The public `struct MatContext` only appears in the header file + assert 'struct MatContext\n{' not in ccode + assert 'struct MatContext\n{' in hcode + + +@skipif('petsc') +@pytest.mark.parametrize('eqn, target, expected', [ + ('Eq(f1.laplace, g1)', + 'f1', ('g1(x, y)', 'Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))')), + ('Eq(g1, f1.laplace)', + 'f1', ('-g1(x, y)', '-Derivative(f1(x, y), (x, 2)) - Derivative(f1(x, y), (y, 2))')), + ('Eq(g1, f1.laplace)', 'g1', + ('Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))', 'g1(x, y)')), + ('Eq(f1 + f1.laplace, g1)', 'f1', ('g1(x, y)', + 'f1(x, y) + Derivative(f1(x, y), (x, 2)) + Derivative(f1(x, y), (y, 2))')), + ('Eq(g1.dx + f1.dx, g1)', 'f1', + ('g1(x, y) - Derivative(g1(x, y), x)', 'Derivative(f1(x, y), x)')), + ('Eq(g1.dx + f1.dx, g1)', 'g1', + ('-Derivative(f1(x, y), x)', '-g1(x, y) + Derivative(g1(x, y), x)')), + ('Eq(f1 * g1.dx, g1)', 'g1', ('0', 'f1(x, y)*Derivative(g1(x, y), x) - g1(x, y)')), + ('Eq(f1 * g1.dx, g1)', 'f1', ('g1(x, y)', 'f1(x, y)*Derivative(g1(x, y), x)')), + ('Eq((f1 * g1.dx).dy, f1)', 'f1', + ('0', '-f1(x, y) + Derivative(f1(x, y)*Derivative(g1(x, y), x), y)')), + ('Eq((f1 * g1.dx).dy, f1)', 'g1', + ('f1(x, y)', 'Derivative(f1(x, y)*Derivative(g1(x, y), x), y)')), + ('Eq(f2.laplace, g2)', 'g2', + ('-Derivative(f2(t, x, y), (x, 2)) - Derivative(f2(t, x, y), (y, 2))', + '-g2(t, x, y)')), + ('Eq(f2.laplace, g2)', 'f2', ('g2(t, x, y)', + 'Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2))')), + ('Eq(f2.laplace, f2)', 'f2', ('0', + '-f2(t, x, y) + Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2))')), + ('Eq(f2*g2, f2)', 'f2', ('0', 'f2(t, x, 
y)*g2(t, x, y) - f2(t, x, y)')), + ('Eq(f2*g2, f2)', 'g2', ('f2(t, x, y)', 'f2(t, x, y)*g2(t, x, y)')), + ('Eq(g2*f2.laplace, f2)', 'g2', ('f2(t, x, y)', + '(Derivative(f2(t, x, y), (x, 2)) + Derivative(f2(t, x, y), (y, 2)))*g2(t, x, y)')), + ('Eq(f2.forward, f2)', 'f2.forward', ('f2(t, x, y)', 'f2(t + dt, x, y)')), + ('Eq(f2.forward, f2)', 'f2', ('-f2(t + dt, x, y)', '-f2(t, x, y)')), + ('Eq(f2.forward.laplace, f2)', 'f2.forward', ('f2(t, x, y)', + 'Derivative(f2(t + dt, x, y), (x, 2)) + Derivative(f2(t + dt, x, y), (y, 2))')), + ('Eq(f2.forward.laplace, f2)', 'f2', + ('-Derivative(f2(t + dt, x, y), (x, 2)) - Derivative(f2(t + dt, x, y), (y, 2))', + '-f2(t, x, y)')), + ('Eq(f2.laplace + f2.forward.laplace, g2)', 'f2.forward', + ('g2(t, x, y) - Derivative(f2(t, x, y), (x, 2)) - Derivative(f2(t, x, y), (y, 2))', + 'Derivative(f2(t + dt, x, y), (x, 2)) + Derivative(f2(t + dt, x, y), (y, 2))')), + ('Eq(g2.laplace, f2 + g2.forward)', 'g2.forward', + ('f2(t, x, y) - Derivative(g2(t, x, y), (x, 2)) - Derivative(g2(t, x, y), (y, 2))', + '-g2(t + dt, x, y)')) +]) +def test_separate_eqn(eqn, target, expected): + """ + Test the separate_eqn function. + + This function is called within PETScSolve to decompose the equation + into the form F(x) = b. This is necessary to utilise the SNES + interface in PETSc. 
+ """ + grid = Grid((2, 2)) + + so = 2 + + f1 = Function(name='f1', grid=grid, space_order=so) # noqa + g1 = Function(name='g1', grid=grid, space_order=so) # noqa + + f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa + g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa + + b, F = separate_eqn(eval(eqn), eval(target)) + expected_b, expected_F = expected + + assert str(b) == expected_b + assert str(F) == expected_F + + +@skipif('petsc') +@pytest.mark.parametrize('expr, so, target, expected', [ + ('f1.laplace', 2, 'f1', '-2.0*f1(x, y)/h_y**2 - 2.0*f1(x, y)/h_x**2'), + ('f1 + f1.laplace', 2, 'f1', + 'f1(x, y) - 2.0*f1(x, y)/h_y**2 - 2.0*f1(x, y)/h_x**2'), + ('g1.dx + f1.dx', 2, 'f1', '-f1(x, y)/h_x'), + ('10 + f1.dx2', 2, 'g1', '0'), + ('(f1 * g1.dx).dy', 2, 'f1', + '(-1/h_y)*(-g1(x, y)/h_x + g1(x + h_x, y)/h_x)*f1(x, y)'), + ('(f1 * g1.dx).dy', 2, 'g1', '-(-1/h_y)*f1(x, y)*g1(x, y)/h_x'), + ('f2.laplace', 2, 'f2', '-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2'), + ('f2*g2', 2, 'f2', 'f2(t, x, y)*g2(t, x, y)'), + ('g2*f2.laplace', 2, 'f2', + '(-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2)*g2(t, x, y)'), + ('f2.forward', 2, 'f2.forward', 'f2(t + dt, x, y)'), + ('f2.forward.laplace', 2, 'f2.forward', + '-2.0*f2(t + dt, x, y)/h_y**2 - 2.0*f2(t + dt, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 2, 'f2.forward', + '-2.0*f2(t + dt, x, y)/h_y**2 - 2.0*f2(t + dt, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 2, + 'f2', '-2.0*f2(t, x, y)/h_y**2 - 2.0*f2(t, x, y)/h_x**2'), + ('f2.laplace', 4, 'f2', '-2.5*f2(t, x, y)/h_y**2 - 2.5*f2(t, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 4, 'f2.forward', + '-2.5*f2(t + dt, x, y)/h_y**2 - 2.5*f2(t + dt, x, y)/h_x**2'), + ('f2.laplace + f2.forward.laplace', 4, 'f2', + '-2.5*f2(t, x, y)/h_y**2 - 2.5*f2(t, x, y)/h_x**2'), + ('f2.forward*f2.forward.laplace', 4, 'f2.forward', + '(-2.5*f2(t + dt, x, y)/h_y**2 - 2.5*f2(t + dt, x, y)/h_x**2)*f2(t + dt, x, y)') +]) +def 
test_centre_stencil(expr, so, target, expected): + """ + Test extraction of centre stencil from an equation. + """ + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=so) # noqa + g1 = Function(name='g1', grid=grid, space_order=so) # noqa + + f2 = TimeFunction(name='f2', grid=grid, space_order=so) # noqa + g2 = TimeFunction(name='g2', grid=grid, space_order=so) # noqa + + centre = centre_stencil(eval(expr), eval(target)) + + assert str(centre) == expected + + +@skipif('petsc') +def test_callback_arguments(): + """ + Test the arguments of each callback function. + """ + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + eqn1 = Eq(f1.laplace, g1) + + petsc1 = PETScSolve(eqn1, f1) + + with switchconfig(openmp=False): + op = Operator(petsc1) + + mv = op._func_table['MyMatShellMult_0'].root + ff = op._func_table['FormFunction_0'].root + + assert len(mv.parameters) == 3 + assert len(ff.parameters) == 4 + + assert str(mv.parameters) == '(J_0, X_global_0, Y_global_0)' + assert str(ff.parameters) == '(snes_0, X_global_0, Y_global_0, dummy_0)' + + +@skipif('petsc') +def test_petsc_struct(): + + grid = Grid((2, 2)) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + mu1 = Constant(name='mu1', value=2.0) + mu2 = Constant(name='mu2', value=2.0) + + eqn1 = Eq(f1.laplace, g1*mu1) + petsc1 = PETScSolve(eqn1, f1) + + eqn2 = Eq(f1, g1*mu2) + + with switchconfig(openmp=False): + op = Operator([eqn2] + petsc1) + + arguments = op.arguments() + + # Check mu1 and mu2 in arguments + assert 'mu1' in arguments + assert 'mu2' in arguments + + # Check mu1 and mu2 in op.parameters + assert mu1 in op.parameters + assert mu2 in op.parameters + + # Check PETSc struct not in op.parameters + assert all(not isinstance(i, CCompositeObject) for i in op.parameters) + + +@skipif('petsc') +@pytest.mark.parallel(mode=[2, 4, 8]) +def 
test_apply(mode): + + grid = Grid(shape=(13, 13), dtype=np.float64) + + pn = Function(name='pn', grid=grid, space_order=2, dtype=np.float64) + rhs = Function(name='rhs', grid=grid, space_order=2, dtype=np.float64) + mu = Constant(name='mu', value=2.0) + + eqn = Eq(pn.laplace*mu, rhs, subdomain=grid.interior) + + petsc = PETScSolve(eqn, pn) + + # Build the op + with switchconfig(openmp=False, mpi=True): + op = Operator(petsc) + + # Check the Operator runs without errors. Not verifying output for + # now. Need to consolidate BC implementation + op.apply() + + # Verify that users can override `mu` + mu_new = Constant(name='mu_new', value=4.0) + op.apply(mu=mu_new) + + +@skipif('petsc') +def test_petsc_frees(): + + grid = Grid((2, 2)) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f.laplace, g) + petsc = PETScSolve(eqn, f) + + with switchconfig(openmp=False): + op = Operator(petsc) + + frees = op.body.frees + + # Check the frees appear in the following order + assert str(frees[0]) == 'PetscCall(VecDestroy(&(b_global_0)));' + assert str(frees[1]) == 'PetscCall(VecDestroy(&(x_global_0)));' + assert str(frees[2]) == 'PetscCall(MatDestroy(&(J_0)));' + assert str(frees[3]) == 'PetscCall(SNESDestroy(&(snes_0)));' + assert str(frees[4]) == 'PetscCall(DMDestroy(&(da_so_2)));' + + +@skipif('petsc') +def test_calls_to_callbacks(): + + grid = Grid((2, 2)) + + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=2) + + eqn = Eq(f.laplace, g) + petsc = PETScSolve(eqn, f) + + with switchconfig(openmp=False): + op = Operator(petsc) + + ccode = str(op.ccode) + + assert '(void (*)(void))MyMatShellMult_0' in ccode + assert 'PetscCall(SNESSetFunction(snes_0,NULL,FormFunction_0,NULL));' in ccode + + +@skipif('petsc') +def test_start_ptr(): + """ + Verify that a pointer to the start of the memory address is correctly + generated for TimeFunction objects. 
This pointer should indicate the + beginning of the multidimensional array that will be overwritten at + the current time step. + This functionality is crucial for VecReplaceArray operations, as it ensures + that the correct memory location is accessed and modified during each time step. + """ + grid = Grid((11, 11)) + u1 = TimeFunction(name='u1', grid=grid, space_order=2, dtype=np.float32) + eq1 = Eq(u1.dt, u1.laplace, subdomain=grid.interior) + petsc1 = PETScSolve(eq1, u1.forward) + + with switchconfig(openmp=False): + op1 = Operator(petsc1) + + # Verify the case with modulo time stepping + assert 'float * start_ptr_0 = t1*localsize_0 + (float*)(u1_vec->data);' in str(op1) + + # Verify the case with no modulo time stepping + u2 = TimeFunction(name='u2', grid=grid, space_order=2, dtype=np.float32, save=5) + eq2 = Eq(u2.dt, u2.laplace, subdomain=grid.interior) + petsc2 = PETScSolve(eq2, u2.forward) + + with switchconfig(openmp=False): + op2 = Operator(petsc2) + + assert 'float * start_ptr_0 = (time + 1)*localsize_0 + ' + \ + '(float*)(u2_vec->data);' in str(op2) + + +@skipif('petsc') +def test_time_loop(): + """ + Verify the following: + - Modulo dimensions are correctly assigned and updated in the PETSc struct + at each time step. + - Only assign/update the modulo dimensions required by any of the + PETSc callback functions. 
+ """ + grid = Grid((11, 11)) + + # Modulo time stepping + u1 = TimeFunction(name='u1', grid=grid, space_order=2) + v1 = Function(name='v1', grid=grid, space_order=2) + eq1 = Eq(v1.laplace, u1) + petsc1 = PETScSolve(eq1, v1) + with switchconfig(openmp=False): + op1 = Operator(petsc1) + body1 = str(op1.body) + rhs1 = str(op1._func_table['FormRHS_0'].root.ccode) + + assert 'ctx.t0 = t0' in body1 + assert 'ctx.t1 = t1' not in body1 + assert 'formrhs->t0' in rhs1 + assert 'formrhs->t1' not in rhs1 + + # Non-modulo time stepping + u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) + v2 = Function(name='v2', grid=grid, space_order=2, save=5) + eq2 = Eq(v2.laplace, u2) + petsc2 = PETScSolve(eq2, v2) + with switchconfig(openmp=False): + op2 = Operator(petsc2) + body2 = str(op2.body) + rhs2 = str(op2._func_table['FormRHS_0'].root.ccode) + + assert 'ctx.time = time' in body2 + assert 'formrhs->time' in rhs2 + + # Modulo time stepping with more than one time step + # used in one of the callback functions + eq3 = Eq(v1.laplace, u1 + u1.forward) + petsc3 = PETScSolve(eq3, v1) + with switchconfig(openmp=False): + op3 = Operator(petsc3) + body3 = str(op3.body) + rhs3 = str(op3._func_table['FormRHS_0'].root.ccode) + + assert 'ctx.t0 = t0' in body3 + assert 'ctx.t1 = t1' in body3 + assert 'formrhs->t0' in rhs3 + assert 'formrhs->t1' in rhs3 + + # Multiple petsc solves within the same time loop + v2 = Function(name='v2', grid=grid, space_order=2) + eq4 = Eq(v1.laplace, u1) + petsc4 = PETScSolve(eq4, v1) + eq5 = Eq(v2.laplace, u1) + petsc5 = PETScSolve(eq5, v2) + with switchconfig(openmp=False): + op4 = Operator(petsc4 + petsc5) + body4 = str(op4.body) + + assert 'ctx.t0 = t0' in body4 + assert body4.count('ctx.t0 = t0') == 1