diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 1c18b04cd5..f2a5551824 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -102,7 +102,9 @@ def init_petsc(): NonNestedHierarchy, SemiCoarsenedExtrudedHierarchy, prolong, restrict, inject, TransferManager, OpenCascadeMeshHierarchy, AdaptiveMeshHierarchy, - AdaptiveTransferManager + AdaptiveTransferManager, + CoarsePatchTransferManager, + FinePatchTransferManager, ) from firedrake.norms import errornorm, norm # noqa: F401 from firedrake.nullspace import VectorSpaceBasis, MixedVectorSpaceBasis # noqa: F401 diff --git a/firedrake/mg/__init__.py b/firedrake/mg/__init__.py index 12f8ec7c36..d34a6d128c 100644 --- a/firedrake/mg/__init__.py +++ b/firedrake/mg/__init__.py @@ -9,3 +9,6 @@ from firedrake.mg.opencascade_mh import OpenCascadeMeshHierarchy # noqa F401 from firedrake.mg.adaptive_hierarchy import AdaptiveMeshHierarchy # noqa F401 from firedrake.mg.adaptive_transfer_manager import AdaptiveTransferManager # noqa: F401 +from firedrake.mg.robust_transfer_manager import ( # noqa: F401 + CoarsePatchTransferManager, FinePatchTransferManager, +) diff --git a/firedrake/mg/mesh.py b/firedrake/mg/mesh.py index 2279cbc67d..5bb183ddcd 100644 --- a/firedrake/mg/mesh.py +++ b/firedrake/mg/mesh.py @@ -10,6 +10,7 @@ from functools import cached_property from firedrake import utils +from firedrake.petsc import PETSc from firedrake.cython import mgimpl as impl from .utils import set_level @@ -28,6 +29,8 @@ class HierarchyBase(object): :arg refinements_per_level: number of mesh refinements each multigrid level should "see". :arg nested: Is this mesh hierarchy nested? + :arg coarse_facet_label: Optional subdomain ID to label the coarse facets on + each level of the hierarchy. .. note:: @@ -36,7 +39,7 @@ class HierarchyBase(object): :func:`ExtrudedMeshHierarchy`, or :func:`NonNestedHierarchy`. """ def __init__(self, meshes, coarse_to_fine_cells, fine_to_coarse_cells, - refinements_per_level=1, nested=False): + refinements_per_level=1, nested=False, coarse_facet_label=None): petsctools.cite("Mitchell2016") self._meshes = tuple(meshes) self.meshes = tuple(meshes[::refinements_per_level]) @@ -44,6 +47,7 @@ def __init__(self, meshes, coarse_to_fine_cells, fine_to_coarse_cells, self.fine_to_coarse_cells = fine_to_coarse_cells self.refinements_per_level = refinements_per_level self.nested = nested + self._coarse_facet_label = coarse_facet_label for level, m in enumerate(meshes): set_level(m, self, Fraction(level, refinements_per_level)) for level, m in enumerate(self): @@ -78,7 +82,8 @@ def MeshHierarchy(mesh, refinement_levels, netgen_flags=False, reorder=None, distribution_parameters=None, callbacks=None, - mesh_builder=firedrake.Mesh): + mesh_builder=firedrake.Mesh, + coarse_facet_label=None): """Build a hierarchy of meshes by uniformly refining a coarse mesh. Parameters @@ -108,6 +113,10 @@ def MeshHierarchy(mesh, refinement_levels, callback receives the refined DM (and the current level). mesh_builder Function to turn a DM into a ``Mesh``. Used by pyadjoint. + coarse_facet_label : int | None + Optional subdomain ID to label the coarse facets on each + level of the hierarchy. + Returns ------- A :py:class:`HierarchyBase` object representing the @@ -138,11 +147,29 @@ def MeshHierarchy(mesh, refinement_levels, else: before = after = lambda dm, i: None for i in range(refinement_levels*refinements_per_level): + if coarse_facet_label is not None: + # Create a temporary label on all the facets of the coarse dm + # to label every coarse facet on the fine dm + fstart, fend = cdm.getHeightStratum(1) + iset = PETSc.IS().createStride(fend-fstart, first=fstart, comm=cdm.comm) + cdm.createLabel("temp_label") + label = cdm.getLabel("temp_label") + label.setStratumIS(1, iset) + if i % refinements_per_level == 0: before(cdm, i) rdm = cdm.refine() if i % refinements_per_level == 0: after(rdm, i) + + if coarse_facet_label is not None: + # Move coarse_facet_label into FACE_SETS_LABEL + iset = rdm.getLabel("temp_label").getStratumIS(1) + label = rdm.getLabel(dmcommon.FACE_SETS_LABEL) + label.setStratumIS(coarse_facet_label, iset) + rdm.removeLabel("temp_label") + cdm.removeLabel("temp_label") + dms.append(rdm) cdm = rdm # Fix up coords if refining embedded circle or sphere @@ -191,7 +218,8 @@ def MeshHierarchy(mesh, refinement_levels, fine_to_coarse_cells = dict((Fraction(i, refinements_per_level), f2c) for i, f2c in enumerate(fine_to_coarse_cells)) return HierarchyBase(meshes, coarse_to_fine_cells, fine_to_coarse_cells, - refinements_per_level, nested=True) + refinements_per_level, nested=True, + coarse_facet_label=coarse_facet_label) def ExtrudedMeshHierarchy(base_hierarchy, height, base_layer=-1, refinement_ratio=2, layers=None, diff --git a/firedrake/mg/robust_transfer_manager.py b/firedrake/mg/robust_transfer_manager.py new file mode 100644 index 0000000000..3163009903 --- /dev/null +++ b/firedrake/mg/robust_transfer_manager.py @@ -0,0 +1,296 @@ +from functools import partial +from ufl import H1 +from finat.ufl import FiniteElement, NodalEnrichedElement, TensorElement + +from firedrake import dmhooks +from firedrake.assemble import assemble, get_assembler +from firedrake.bcs import DirichletBC, restricted_function_space +from firedrake.function import Function +from firedrake.interpolation import interpolate, get_interpolator +from firedrake.slate import Inverse, Tensor +from firedrake.ufl_expr import action, TestFunction, TrialFunction +from firedrake.utils import complex_mode +from firedrake.variational_solver import LinearVariationalProblem, LinearVariationalSolver +from .embedded import TransferManager +from .utils import get_level + + +__all__ = ("CoarsePatchTransferManager", "FinePatchTransferManager", "RobustTransferManager") + + +class RobustTransferManager(TransferManager): + """An object for managing transfers between levels in a multigrid hierarchy + via standard interpolation into subdomain boundaries followed by an extension + into the interior of the subdomains by solving the homogeneous PDE. + + :kwarg native_transfers: dict mapping UFL element + to "natively supported" transfer operators. This should be + a three-tuple of (prolong, restrict, inject). + :kwarg use_averaging: Use averaging to approximate the + projection out of the embedded DG space? If False, a global + L2 projection will be performed. + """ + + class TransferCallable: + """Internal class to apply a sequence on linear operations + by transfering the input and output into local buffers + referenced in the list of callables. + """ + def __init__(self, x_buffer, y_buffer, callables): + self.x_buffer = x_buffer + self.y_buffer = y_buffer + self.callables = callables + + def __call__(self, x, y): + self.x_buffer.assign(x) + for c in self.callables: + c() + return y.assign(self.y_buffer) + + def __init__(self, native_transfers=None, use_averaging=True): + super().__init__(native_transfers=native_transfers, + use_averaging=use_averaging) + self.direct_solver_parameters = { + "ksp_type": "preonly", + "pc_type": "bjacobi", + "sub_pc_type": "cholesky", + "sub_pc_factor_mat_solver_type": "cholmod", + } + + def form(self, V): + """Get the preconditioning Form in the solver context of a FunctionSpace.""" + ctx = dmhooks.get_appctx(V.dm) + if ctx is not None: + return ctx._problem.Jp or ctx._problem.J + else: + return None + + def auxiliary_target_space(self, V): + """Construct an auxiliary target FunctionSpace.""" + raise NotImplementedError("Must be implemented by subclass.") + + def build_patch_solver(self, form, V): + """Build a solver to extend the solution from the residual in the + auxiliary space into the entire space V.""" + raise NotImplementedError("Must be implemented by subclass.") + + def get_patch_solver(self, form, V): + """Cache the patch solver.""" + cache = form._cache + key = (type(self).__name__, "patch_solver") + try: + return cache[key] + except KeyError: + return cache.setdefault(key, self.build_patch_solver(form, V)) + + def build_transfer_callables(self, form, Vc, Vf): + """Construct prolongation and restriction TransferCallables.""" + uc = Function(Vc) + uf = Function(Vf) + P = self.prolong_callable(form, uc, uf) + rc = Function(Vc.dual(), val=uc.dat) + rf = Function(Vf.dual(), val=uf.dat) + R = self.restrict_callable(form, rf, rc) + return P, R + + def get_transfer_callables(self, Vc, Vf): + """Cache the prolongation and restriction TransferCallables.""" + form = self.form(Vf) + cache = form._cache + key = (type(self).__name__, "transfer_callables") + try: + return cache[key] + except KeyError: + return cache.setdefault(key, self.build_transfer_callables(form, Vc, Vf)) + + def prolong_callable(self, form, uc, uf): + """Return a TransferCallable that interpolates uc into uf such that + uc = uf on patch boundaries and form(v, uf) = 0 for all v on the patch + subspaces.""" + V = uf.function_space() + V_aux = self.auxiliary_target_space(V) + u_aux = Function(V_aux) + + solver, r_patch, u_patch = self.get_patch_solver(form, V) + if solver is None: + # patch problem is empty + callables = ( + partial(TransferManager.prolong, self, uc, u_aux), + partial(u_aux.dat.copy, uf.dat), + ) + else: + btest, = r_patch.arguments() + if len(set(f.ufl_element() for f in (uf, u_aux, u_patch))) == 1: + copy_update = partial(uf.assign, u_aux - u_patch) + else: + wtest = TestFunction(V.dual()) + Iv = get_interpolator(interpolate(u_aux - u_patch, wtest)) + copy_update = partial(Iv.assemble, tensor=uf) + + residual = get_assembler(form(btest, u_aux)) + callables = ( + partial(TransferManager.prolong, self, uc, u_aux), + partial(residual.assemble, tensor=r_patch), + solver, + copy_update, + ) + return self.TransferCallable(uc, uf, callables) + + def restrict_callable(self, form, rf, rc): + """Return a TransferCallable with the adjoint of prolong.""" + V = rf.function_space().dual() + V_aux = self.auxiliary_target_space(V) + r_aux = Function(V_aux.dual()) + Au = Function(V_aux.dual()) + + solver, r_patch, u_patch = self.get_patch_solver(form, V) + if solver is None: + # patch problem is empty + callables = ( + partial(rf.dat.copy, r_aux.dat), + partial(TransferManager.restrict, self, r_aux, rc), + ) + else: + btest, = r_patch.arguments() + vtest = TestFunction(V_aux) + if len(set(f.ufl_element() for f in (rf, r_aux, r_patch))) == 1: + copy_aux = partial(r_aux.assign, rf) + copy_rhs = partial(r_patch.assign, rf) + else: + Iv = get_interpolator(interpolate(vtest, rf)) + Ib = get_interpolator(interpolate(btest, rf)) + copy_aux = partial(Iv.assemble, tensor=r_aux) + copy_rhs = partial(Ib.assemble, tensor=r_patch) + + residual = get_assembler(form(u_patch, vtest)) + callables = ( + copy_rhs, + solver, + partial(residual.assemble, tensor=Au), + copy_aux, + partial(r_aux.assign, r_aux - Au), + partial(TransferManager.restrict, self, r_aux, rc), + ) + return self.TransferCallable(rf, rc, callables) + + def prolong(self, uc, uf): + Vc = uc.function_space() + Vf = uf.function_space() + form = self.form(Vf) + if form is not None: + P, R = self.get_transfer_callables(Vc, Vf) + return P(uc, uf) + else: + return super().prolong(uc, uf) + + def restrict(self, rf, rc): + Vc = rc.function_space().dual() + Vf = rf.function_space().dual() + form = self.form(Vf) + if form is not None: + P, R = self.get_transfer_callables(Vc, Vf) + return R(rf, rc) + else: + return super().restrict(rf, rc) + + +class CoarsePatchTransferManager(RobustTransferManager): + """An object for managing transfers between levels in a multigrid hierarchy + via standard interpolation into coarse cell boundaries followed by an extension + into the interior of the coarse cell patches by solving the homogeneous PDE. + + This class will raise an error when the coarse facets are not labeled across + the MeshHierarchy. + """ + + def auxiliary_target_space(self, U): + """Construct a standard space for inter-grid interpolation.""" + return U.reconstruct(variant=None, quad_scheme=None) + + def build_patch_solver(self, form, V): + """Solve form(test, u_patch) = r_patch on coarse cell patches.""" + mesh = V.mesh().unique() + marker = self.get_coarse_facet_label(mesh) + V_patch = restricted_function_space(V, [(marker,)]) + u_patch = Function(V_patch) + r_patch = Function(V_patch.dual()) + test = TestFunction(V_patch) + trial = TrialFunction(V_patch) + + bcs = DirichletBC(V_patch, 0, marker) + a = assemble(form(test, trial), bcs=bcs) + problem = LinearVariationalProblem(a, r_patch, u_patch) + solver = LinearVariationalSolver(problem, solver_parameters=self.direct_solver_parameters) + return (solver.solve, r_patch, u_patch) + + def get_coarse_facet_label(self, mesh): + markers = tuple(mesh.interior_facets.unique_markers) + mh, _ = get_level(mesh) + label = mh._coarse_facet_label + if label is not None and label in markers: + return label + raise ValueError("Expecting a hierarchy with a coarse facet label.") + + +class FinePatchTransferManager(RobustTransferManager): + """An object for managing transfers between levels in a multigrid hierarchy + via standard interpolation into fine cell boundaries followed by an extension + into the interior of the fine cells by solving the homogeneous PDE. + """ + + def auxiliary_target_space(self, U): + """Construct a facet space for inter-grid interpolation.""" + quad_scheme = None + element = U.ufl_element() + if U.finat_element.complex.is_macrocell(): + # Macroelements require a composite quadrature scheme + if element.sobolev_space == H1 and U.finat_element.degree < 4: + quad_scheme = "powell-sabin,KMV(2)" + else: + quad_scheme = "powell-sabin" + + tdim = U.mesh().topological_dimension + if U.finat_element.has_pointwise_dual_basis and U.finat_element.degree == tdim: + # Facet moment degrees of freedom for CG elements + CG = FiniteElement("CG", degree=tdim, variant="chebyshev") + CR = FiniteElement("CR", degree=1, variant="integral", quad_scheme=quad_scheme) + element = NodalEnrichedElement(CG["ridge"], CR) + if U.value_shape != (): + element = TensorElement(element, shape=U.value_shape) + else: + # Take the facet element with the new quadrature scheme + if quad_scheme is not None: + element = element.reconstruct(quad_scheme=quad_scheme) + element = element["facet"] + return U.reconstruct(element=element) + + def build_patch_solver(self, form, V): + """Solve form(test, u_patch) = r_patch on fine cell patches""" + # Reconstruct the space on the interior without a special quadrature + + tdim = V.mesh().topological_dimension + entity_dofs = V.finat_element.entity_dofs() + if len(entity_dofs[tdim][0]) == 0: + return (None, None, None) + + element = V.ufl_element() + if element._quad_scheme is not None: + element = element.reconstruct(quad_scheme=None) + V_patch = V.reconstruct(element=element["interior"]) + u_patch = Function(V_patch) + r_patch = Function(V_patch.dual()) + test = TestFunction(V_patch) + trial = TrialFunction(V_patch) + a = form(test, trial) + + use_slate_for_inverse = not complex_mode + if use_slate_for_inverse: + ainv = assemble(Inverse(Tensor(a))) + assembler = get_assembler(action(ainv, r_patch)) + solve = partial(assembler.assemble, tensor=u_patch) + else: + a = assemble(a) + problem = LinearVariationalProblem(a, r_patch, u_patch) + solver = LinearVariationalSolver(problem, solver_parameters=self.direct_solver_parameters) + solve = solver.solve + return (solve, r_patch, u_patch) diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index d5a26a36e6..cf021d0718 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -112,6 +112,13 @@ def coarsen_form(form, self, coefficient_mapping=None): return form +@coarsen.register(ufl.Interpolate) +def coarsen_interpolate(interp, self, coefficient_mapping=None): + dual_arg, operand = interp.argument_slots() + return interp._ufl_expr_reconstruct_(self(operand, self, coefficient_mapping=coefficient_mapping), + self(dual_arg, self, coefficient_mapping=coefficient_mapping)) + + @coarsen.register(ufl.FormSum) def coarsen_formsum(form, self, coefficient_mapping=None): return type(form)(*[(self(ci, self, coefficient_mapping=coefficient_mapping), @@ -156,6 +163,7 @@ def coarsen_function_space(V, self, coefficient_mapping=None): V_coarse = V_fine.reconstruct(mesh=mesh_coarse, name=name) V_coarse._fine = V_fine V_fine._coarse = V_coarse + return V_coarse @@ -310,15 +318,13 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): coarse_sub_mat_type = context.sub_mat_type coarse_sub_pmat_type = context.sub_pmat_type - coarse = _SNESContext(problem, - mat_type=coarse_mat_type, - pmat_type=coarse_pmat_type, - sub_mat_type=coarse_sub_mat_type, - sub_pmat_type=coarse_sub_pmat_type, - appctx=new_appctx, - options_prefix=context.options_prefix, - transfer_manager=context.transfer_manager, - pre_apply_bcs=context.pre_apply_bcs) + coarse = context.reconstruct(problem=problem, + mat_type=coarse_mat_type, + pmat_type=coarse_pmat_type, + sub_mat_type=coarse_sub_mat_type, + sub_pmat_type=coarse_sub_pmat_type, + appctx=new_appctx, + ) coarse._coefficient_mapping = coefficient_mapping coarse._fine = context context._coarse = coarse diff --git a/tests/firedrake/multigrid/test_robust_transfer.py b/tests/firedrake/multigrid/test_robust_transfer.py new file mode 100644 index 0000000000..7b4637bc4b --- /dev/null +++ b/tests/firedrake/multigrid/test_robust_transfer.py @@ -0,0 +1,80 @@ +import pytest +from firedrake import * + + +@pytest.fixture +def hierarchy(): + distribution_parameters = {"overlap_type": (DistributedMeshOverlapType.VERTEX, 1)} + nx = 4 + refine = 3 + base = UnitSquareMesh(nx, nx, distribution_parameters=distribution_parameters) + mh = MeshHierarchy(base, refine, coarse_facet_label=1000) + return mh + + +@pytest.fixture +def mesh(hierarchy): + return hierarchy[-1] + + +@pytest.fixture +def V(mesh): + degree = mesh.topological_dimension + V = VectorFunctionSpace(mesh, "CG", degree, variant="alfeld") + return V + + +@pytest.fixture +def solver(V): + uh = Function(V) + u = TrialFunction(V) + v = TestFunction(V) + x = SpatialCoordinate(V.mesh()) + uexact = x * sum(x) + + mu = Constant(1) + lam = Constant(1E4) + eps = lambda u: sym(grad(u)) + a = inner(2*mu*eps(u), eps(v))*dx + inner(lam*div(u), div(v))*dx + L = a(v, uexact) + bcs = DirichletBC(V, uexact, "on_boundary") + + solver_parameters = { + "mat_type": "aij", + "snes_type": "ksponly", + "ksp_type": "cg", + "ksp_rtol": 1e-8, + "ksp_monitor": None, + "pc_type": "mg", + "mg_levels": { + "ksp_type": "chebyshev", + "ksp_max_it": 2, + "pc_type": "python", + "pc_python_type": "firedrake.ASMStarPC", + "pc_star_sub_sub_pc_type": "cholesky", + "pc_star_sub_sub_pc_factor_mat_solver_type": "petsc", + "pc_star_mat_ordering_type": "nd", + "pc_star_use_coloring": True, + }, + "mg_coarse": { + "mat_type": "aij", + "pc_type": "cholesky", + "pc_factor_mat_solver_type": "mumps", + } + } + + problem = LinearVariationalProblem(a, L, uh, bcs=bcs) + solver = LinearVariationalSolver(problem, + solver_parameters=solver_parameters) + return solver + + +@pytest.mark.parallel([1, 3]) +@pytest.mark.parametrize("create_transfer", [CoarsePatchTransferManager, FinePatchTransferManager]) +def test_robust_transfer(solver, create_transfer): + tm = create_transfer() + u = solver._problem.u + u.zero() + solver.set_transfer_manager(tm) + solver.solve() + assert solver.snes.ksp.getIterationNumber() < 15