From aebe6cd0edc13c9486e920e99d4eacaa85584eec Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Fri, 12 Apr 2024 02:34:36 +0100 Subject: [PATCH] Enable solving multi-domain problems involving codim-0 submeshes --- .../saddle_point_systems.py.rst | 12 +- firedrake/assemble.py | 243 ++++++--- firedrake/bcs.py | 2 - firedrake/checkpointing.py | 12 +- firedrake/dmhooks.py | 17 +- firedrake/ensemble/ensemble_functionspace.py | 2 +- firedrake/function.py | 12 +- firedrake/functionspace.py | 21 +- firedrake/functionspaceimpl.py | 81 ++- firedrake/mesh.py | 195 ++++++++ firedrake/mg/kernels.py | 8 +- firedrake/mg/ufl_utils.py | 5 +- firedrake/pointeval_utils.py | 2 +- firedrake/pointquery_utils.py | 3 +- firedrake/preconditioners/asm.py | 34 +- firedrake/preconditioners/fdm.py | 2 +- firedrake/preconditioners/patch.py | 121 ++--- firedrake/preconditioners/pmg.py | 12 +- firedrake/slate/slac/compiler.py | 10 +- firedrake/slate/slac/kernel_builder.py | 39 +- firedrake/slate/slate.py | 15 +- .../static_condensation/hybridization.py | 16 +- firedrake/slate/static_condensation/scpc.py | 2 +- firedrake/tsfc_interface.py | 33 +- firedrake/ufl_expr.py | 88 ++-- .../regression/test_assemble_baseform.py | 4 +- .../regression/test_function_spaces.py | 4 +- .../regression/test_multiple_domains.py | 12 +- .../submesh/test_submesh_assemble.py | 328 +++++++++++++ tests/firedrake/submesh/test_submesh_base.py | 275 +++++++++++ tests/firedrake/submesh/test_submesh_solve.py | 460 ++++++++++++++++++ tests/tsfc/test_tsfc_182.py | 5 +- tests/tsfc/test_tsfc_204.py | 5 +- tsfc/driver.py | 35 +- tsfc/fem.py | 90 ++-- tsfc/kernel_args.py | 4 +- tsfc/kernel_interface/__init__.py | 12 +- tsfc/kernel_interface/common.py | 162 +++--- tsfc/kernel_interface/firedrake_loopy.py | 363 +++++++++----- tsfc/ufl_utils.py | 11 +- 40 files changed, 2241 insertions(+), 516 deletions(-) create mode 100644 tests/firedrake/submesh/test_submesh_assemble.py create mode 100644 tests/firedrake/submesh/test_submesh_base.py create mode 100644 tests/firedrake/submesh/test_submesh_solve.py diff --git a/demos/saddle_point_pc/saddle_point_systems.py.rst b/demos/saddle_point_pc/saddle_point_systems.py.rst index 752836626a..f9138be076 100644 --- a/demos/saddle_point_pc/saddle_point_systems.py.rst +++ b/demos/saddle_point_pc/saddle_point_systems.py.rst @@ -180,7 +180,7 @@ Finally, at each mesh size, we print out the number of cells in the mesh and the number of iterations the solver took to converge :: # - print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber()) + print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber()) The resulting convergence is unimpressive: @@ -282,7 +282,7 @@ applying the action of blocks, so we can use a block matrix format. :: for n in range(8): solver, w = build_problem(n, parameters, block_matrix=True) solver.solve() - print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber()) + print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber()) The resulting convergence is algorithmically good, however, the larger problems still take a long time. @@ -367,7 +367,7 @@ Let's see what happens. :: for n in range(8): solver, w = build_problem(n, parameters, block_matrix=True) solver.solve() - print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber()) + print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber()) This is much better, the problem takes much less time to solve and when observing the iteration counts for inverting :math:`S` we can see @@ -422,7 +422,7 @@ and so we no longer need a flexible Krylov method. :: for n in range(8): solver, w = build_problem(n, parameters, block_matrix=True) solver.solve() - print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber()) + print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber()) This results in the following GMRES iteration counts @@ -487,7 +487,7 @@ variable. We can provide it as an :class:`~.AuxiliaryOperatorPC` via a python pr for n in range(8): solver, w = build_problem(n, parameters, aP=None, block_matrix=False) solver.solve() - print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber()) + print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber()) This actually results in slightly worse convergence than the diagonal approximation we used above. @@ -571,7 +571,7 @@ Let's see what the iteration count looks like now. :: for n in range(8): solver, w = build_problem(n, parameters, aP=riesz, block_matrix=True) solver.solve() - print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber()) + print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber()) ============== ================== Mesh elements GMRES iterations diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 2c3707cb74..e21fbb4bac 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -19,7 +19,7 @@ from firedrake import (extrusion_utils as eutils, matrix, parameters, solving, tsfc_interface, utils) from firedrake.adjoint_utils import annotate_assemble -from firedrake.ufl_expr import extract_unique_domain +from firedrake.ufl_expr import extract_domains from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key @@ -1027,7 +1027,7 @@ def parloops(self, tensor): self._bcs, local_kernel, subdomain_id, - self.all_integer_subdomain_ids[local_kernel.indices], + self.all_integer_subdomain_ids[local_kernel.indices][local_kernel.kinfo.domain_number], diagonal=self.diagonal, ) pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices) @@ -1049,14 +1049,15 @@ def local_kernels(self): """ try: - topology, = set(d.topology for d in self._form.ufl_domains()) + topology, = set(d.topology.submesh_ancesters[-1] for d in self._form.ufl_domains()) except ValueError: raise NotImplementedError("All integration domains must share a mesh topology") for o in itertools.chain(self._form.arguments(), self._form.coefficients()): - domain = extract_unique_domain(o) - if domain is not None and domain.topology != topology: - raise NotImplementedError("Assembly with multiple meshes is not supported") + domains = extract_domains(o) + for domain in domains: + if domain is not None and domain.topology.submesh_ancesters[-1] != topology: + raise NotImplementedError("Assembly with multiple meshes is not supported") if isinstance(self._form, ufl.Form): kernels = tsfc_interface.compile_form( @@ -1368,12 +1369,12 @@ def _make_maps_and_regions(self): else: maps_and_regions = defaultdict(lambda: defaultdict(set)) for assembler in self._all_assemblers: - all_meshes = assembler._form.ufl_domains() + all_meshes = extract_domains(assembler._form) for local_kernel, subdomain_id in assembler.local_kernels: i, j = local_kernel.indices mesh = all_meshes[local_kernel.kinfo.domain_number] # integration domain integral_type = local_kernel.kinfo.integral_type - all_subdomain_ids = assembler.all_integer_subdomain_ids[local_kernel.indices] + all_subdomain_ids = assembler.all_integer_subdomain_ids[local_kernel.indices][local_kernel.kinfo.domain_number] # Make Sparsity independent of the subdomain of integration for better reusability; # subdomain_id is passed here only to determine the integration_type on the target domain # (see ``entity_node_map``). @@ -1549,6 +1550,10 @@ def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomai # N.B. Generating the global kernel is not a collective operation so the # communicator does not need to be a part of this cache key. + # Maps in the cached global kernel depend on concrete mesh data. + all_meshes = extract_domains(form) + domain_ids = tuple(mesh.ufl_id() for mesh in all_meshes) + if isinstance(form, ufl.Form): sig = form.signature() elif isinstance(form, slate.TensorBase): @@ -1568,7 +1573,8 @@ def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomai else: subdomain_key.append((k, i)) - return ((sig, subdomain_id) + return (domain_ids + + (sig, subdomain_id) + tuple(subdomain_key) + tuplify(all_integer_subdomain_ids) + cachetools.keys.hashkey(local_knl, **kwargs)) @@ -1605,8 +1611,15 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia self._diagonal = diagonal self._unroll = unroll + self._active_coordinates = _FormHandler.iter_active_coordinates(form, local_knl.kinfo) + self._active_cell_orientations = _FormHandler.iter_active_cell_orientations(form, local_knl.kinfo) + self._active_cell_sizes = _FormHandler.iter_active_cell_sizes(form, local_knl.kinfo) self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) + self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo) + self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo) + self._active_orientations_exterior_facet = _FormHandler.iter_active_orientations_exterior_facet(form, local_knl.kinfo) + self._active_orientations_interior_facet = _FormHandler.iter_active_orientations_interior_facet(form, local_knl.kinfo) self._map_arg_cache = {} # Cache for holding :class:`op2.MapKernelArg` instances. @@ -1620,8 +1633,15 @@ def build(self): for arg in self._kinfo.arguments] # we should use up all of the coefficients and constants + assert_empty(self._active_coordinates) + assert_empty(self._active_cell_orientations) + assert_empty(self._active_cell_sizes) assert_empty(self._active_coefficients) assert_empty(self._constants) + assert_empty(self._active_exterior_facets) + assert_empty(self._active_interior_facets) + assert_empty(self._active_orientations_exterior_facet) + assert_empty(self._active_orientations_interior_facet) iteration_regions = {"exterior_facet_top": op2.ON_TOP, "exterior_facet_bottom": op2.ON_BOTTOM, @@ -1646,7 +1666,8 @@ def _integral_type(self): @cached_property def _mesh(self): - return self._form.ufl_domains()[self._kinfo.domain_number] + all_meshes = extract_domains(self._form) + return all_meshes[self._kinfo.domain_number] @cached_property def _needs_subset(self): @@ -1751,7 +1772,22 @@ def _as_global_kernel_arg_output(_, self): @_as_global_kernel_arg.register(kernel_args.CoordinatesKernelArg) def _as_global_kernel_arg_coordinates(_, self): - V = self._mesh.coordinates.function_space() + coord = next(self._active_coordinates) + V = coord.function_space() + return self._make_dat_global_kernel_arg(V) + + +@_as_global_kernel_arg.register(kernel_args.CellOrientationsKernelArg) +def _as_global_kernel_arg_cell_orientations(_, self): + c = next(self._active_cell_orientations) + V = c.function_space() + return self._make_dat_global_kernel_arg(V) + + +@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg) +def _as_global_kernel_arg_cell_sizes(_, self): + c = next(self._active_cell_sizes) + V = c.function_space() return self._make_dat_global_kernel_arg(V) @@ -1779,30 +1815,48 @@ def _as_global_kernel_arg_constant(_, self): return op2.GlobalKernelArg((value_size,)) -@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg) -def _as_global_kernel_arg_cell_sizes(_, self): - V = self._mesh.cell_sizes.function_space() - return self._make_dat_global_kernel_arg(V) - - @_as_global_kernel_arg.register(kernel_args.ExteriorFacetKernelArg) def _as_global_kernel_arg_exterior_facet(_, self): - return op2.DatKernelArg((1,)) + mesh = next(self._active_exterior_facets) + if mesh is self._mesh: + return op2.DatKernelArg((1,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatKernelArg((1,), m._global_kernel_arg) @_as_global_kernel_arg.register(kernel_args.InteriorFacetKernelArg) def _as_global_kernel_arg_interior_facet(_, self): - return op2.DatKernelArg((2,)) + mesh = next(self._active_interior_facets) + if mesh is self._mesh: + return op2.DatKernelArg((2,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatKernelArg((2,), m._global_kernel_arg) -@_as_global_kernel_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) -def _as_global_kernel_arg_exterior_facet_orientation(_, self): - return op2.DatKernelArg((1,)) +@_as_global_kernel_arg.register(kernel_args.OrientationsExteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_exterior_facet) + if mesh is self._mesh: + return op2.DatKernelArg((1,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatKernelArg((1,), m._global_kernel_arg) -@_as_global_kernel_arg.register(kernel_args.InteriorFacetOrientationKernelArg) -def _as_global_kernel_arg_interior_facet_orientation(_, self): - return op2.DatKernelArg((2,)) +@_as_global_kernel_arg.register(kernel_args.OrientationsInteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_interior_facet) + if mesh is self._mesh: + return op2.DatKernelArg((2,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatKernelArg((2,), m._global_kernel_arg) @_as_global_kernel_arg.register(CellFacetKernelArg) @@ -1814,12 +1868,6 @@ def _as_global_kernel_arg_cell_facet(_, self): return op2.DatKernelArg((num_facets, 2)) -@_as_global_kernel_arg.register(kernel_args.CellOrientationsKernelArg) -def _as_global_kernel_arg_cell_orientations(_, self): - V = self._mesh.cell_orientations().function_space() - return self._make_dat_global_kernel_arg(V) - - @_as_global_kernel_arg.register(LayerCountKernelArg) def _as_global_kernel_arg_layer_count(_, self): return op2.GlobalKernelArg((1,)) @@ -1853,8 +1901,15 @@ def __init__(self, form, bcs, local_knl, subdomain_id, self._diagonal = diagonal self._bcs = bcs + self._active_coordinates = _FormHandler.iter_active_coordinates(form, local_knl.kinfo) + self._active_cell_orientations = _FormHandler.iter_active_cell_orientations(form, local_knl.kinfo) + self._active_cell_sizes = _FormHandler.iter_active_cell_sizes(form, local_knl.kinfo) self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) + self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo) + self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo) + self._active_orientations_exterior_facet = _FormHandler.iter_active_orientations_exterior_facet(form, local_knl.kinfo) + self._active_orientations_interior_facet = _FormHandler.iter_active_orientations_interior_facet(form, local_knl.kinfo) def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop: """Construct the parloop. @@ -1988,7 +2043,8 @@ def _indexed_function_spaces(self): @cached_property def _mesh(self): - return self._form.ufl_domains()[self._kinfo.domain_number] + all_meshes = extract_domains(self._form) + return all_meshes[self._kinfo.domain_number] @cached_property def _iterset(self): @@ -2060,7 +2116,21 @@ def _as_parloop_arg_output(_, self): @_as_parloop_arg.register(kernel_args.CoordinatesKernelArg) def _as_parloop_arg_coordinates(_, self): - func = self._mesh.coordinates + func = next(self._active_coordinates) + map_ = self._get_map(func.function_space()) + return op2.DatParloopArg(func.dat, map_) + + +@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg) +def _as_parloop_arg_cell_orientations(_, self): + func = next(self._active_cell_orientations) + map_ = self._get_map(func.function_space()) + return op2.DatParloopArg(func.dat, map_) + + +@_as_parloop_arg.register(kernel_args.CellSizesKernelArg) +def _as_parloop_arg_cell_sizes(_, self): + func = next(self._active_cell_sizes) map_ = self._get_map(func.function_space()) return op2.DatParloopArg(func.dat, map_) @@ -2081,38 +2151,48 @@ def _as_parloop_arg_constant(arg, self): return op2.GlobalParloopArg(const.dat) -@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg) -def _as_parloop_arg_cell_orientations(_, self): - func = self._mesh.cell_orientations() - m = self._get_map(func.function_space()) - return op2.DatParloopArg(func.dat, m) - - -@_as_parloop_arg.register(kernel_args.CellSizesKernelArg) -def _as_parloop_arg_cell_sizes(_, self): - func = self._mesh.cell_sizes - m = self._get_map(func.function_space()) - return op2.DatParloopArg(func.dat, m) - - @_as_parloop_arg.register(kernel_args.ExteriorFacetKernelArg) def _as_parloop_arg_exterior_facet(_, self): - return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_dat) + mesh = next(self._active_exterior_facets) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatParloopArg(mesh.exterior_facets.local_facet_dat, m) @_as_parloop_arg.register(kernel_args.InteriorFacetKernelArg) def _as_parloop_arg_interior_facet(_, self): - return op2.DatParloopArg(self._mesh.interior_facets.local_facet_dat) + mesh = next(self._active_interior_facets) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatParloopArg(mesh.interior_facets.local_facet_dat, m) -@_as_parloop_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) -def _as_parloop_arg_exterior_facet_orientation(_, self): - return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_orientation_dat) +@_as_parloop_arg.register(kernel_args.OrientationsExteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_exterior_facet) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatParloopArg(mesh.exterior_facets.local_facet_orientation_dat, m) -@_as_parloop_arg.register(kernel_args.InteriorFacetOrientationKernelArg) -def _as_parloop_arg_interior_facet_orientation(_, self): - return op2.DatParloopArg(self._mesh.interior_facets.local_facet_orientation_dat) +@_as_parloop_arg.register(kernel_args.OrientationsInteriorFacetKernelArg) +def _(_, self): + mesh = next(self._active_orientations_interior_facet) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatParloopArg(mesh.interior_facets.local_facet_orientation_dat, m) @_as_parloop_arg.register(CellFacetKernelArg) @@ -2134,6 +2214,27 @@ def _as_parloop_arg_layer_count(_, self): class _FormHandler: """Utility class for inspecting forms and local kernels.""" + @staticmethod + def iter_active_coordinates(form, kinfo): + """Yield the form coordinates referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.coordinates: + yield all_meshes[i].coordinates + + @staticmethod + def iter_active_cell_orientations(form, kinfo): + """Yield the form cell orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.cell_orientations: + yield all_meshes[i].cell_orientations() + + @staticmethod + def iter_active_cell_sizes(form, kinfo): + """Yield the form cell sizes referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.cell_sizes: + yield all_meshes[i].cell_sizes + @staticmethod def iter_active_coefficients(form, kinfo): """Yield the form coefficients referenced in ``kinfo``.""" @@ -2152,6 +2253,38 @@ def iter_constants(form, kinfo): for constant_index in kinfo.constant_numbers: yield all_constants[constant_index] + @staticmethod + def iter_active_exterior_facets(form, kinfo): + """Yield the form exterior facets referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.exterior_facets: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_interior_facets(form, kinfo): + """Yield the form interior facets referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.interior_facets: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_orientations_exterior_facet(form, kinfo): + """Yield the form exterior facet orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.orientations_exterior_facet: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_orientations_interior_facet(form, kinfo): + """Yield the form interior facet orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.orientations_interior_facet: + mesh = all_meshes[i] + yield mesh + @staticmethod def index_function_spaces(form, indices): """Return the function spaces of the form's arguments, indexed diff --git a/firedrake/bcs.py b/firedrake/bcs.py index e1408da545..17c199c1d8 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -162,8 +162,6 @@ def hermite_stride(bcnodes): # take intersection of facet nodes, and add it to bcnodes # i, j, k can also be strings. bcnodes1 = [] - if len(s) > 1 and not isinstance(self._function_space.finat_element, (finat.Lagrange, finat.GaussLobattoLegendre)): - raise TypeError("Currently, edge conditions have only been tested with CG Lagrange elements") for ss in s: # intersection of facets # Edge conditions have only been tested with Lagrange elements. diff --git a/firedrake/checkpointing.py b/firedrake/checkpointing.py index 6a63b1aaf1..c8bb4d7cf4 100644 --- a/firedrake/checkpointing.py +++ b/firedrake/checkpointing.py @@ -566,6 +566,8 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None): :kwarg distribution_name: the name under which distribution is saved; if `None`, auto-generated name will be used. :kwarg permutation_name: the name under which permutation is saved; if `None`, auto-generated name will be used. """ + # TODO: Add general MeshSequence support. + mesh = mesh.unique() # Handle extruded mesh tmesh = mesh.topology if mesh.extruded: @@ -835,6 +837,8 @@ def get_timestepping_history(self, mesh, name): @PETSc.Log.EventDecorator("SaveFunctionSpace") def _save_function_space(self, V): mesh = V.mesh() + # TODO: Add general MeshSequence support. + mesh = mesh.unique() if isinstance(V.topological, impl.MixedFunctionSpace): V_name = self._generate_function_space_name(V) base_path = self._path_to_mixed_function_space(mesh.name, V_name) @@ -910,10 +914,12 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}): each index. """ V = f.function_space() - mesh = V.mesh() if name: g = Function(V, val=f.dat, name=name) return self.save_function(g, idx=idx, timestepping_info=timestepping_info) + mesh = V.mesh() + # TODO: Add general MeshSequence support. + mesh = mesh.unique() # -- Save function space -- self._save_function_space(V) # -- Save function -- @@ -1224,6 +1230,8 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters): @PETSc.Log.EventDecorator("LoadFunctionSpace") def _load_function_space(self, mesh, name): + # TODO: Add general MeshSequence support. + mesh = mesh.unique() mesh_key = self._generate_mesh_key_from_names(mesh.name, mesh.topology._distribution_name, mesh.topology._permutation_name) @@ -1299,6 +1307,8 @@ def load_function(self, mesh, name, idx=None): be loaded with idx only when it was saved with idx. :returns: the loaded :class:`~.Function`. """ + # TODO: Add general MeshSequence support. + mesh = mesh.unique() tmesh = mesh.topology if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name): V_name = self._get_mixed_function_name_mixed_function_space_name_map(mesh.name)[name] diff --git a/firedrake/dmhooks.py b/firedrake/dmhooks.py index 046852b2e6..d39f1465c6 100644 --- a/firedrake/dmhooks.py +++ b/firedrake/dmhooks.py @@ -43,6 +43,7 @@ import firedrake from firedrake.petsc import PETSc +from firedrake.mesh import MeshSequenceGeometry @PETSc.Log.EventDecorator() @@ -53,8 +54,11 @@ def get_function_space(dm): :raises RuntimeError: if no function space was found. """ info = dm.getAttr("__fs_info__") - meshref, element, indices, (name, names), boundary_sets = info - mesh = meshref() + meshref_tuple, element, indices, (name, names), boundary_sets = info + if len(meshref_tuple) == 1: + mesh = meshref_tuple[0]() + else: + mesh = MeshSequenceGeometry([meshref() for meshref in meshref_tuple]) if mesh is None: raise RuntimeError("Somehow your mesh was collected, this should never happen") V = firedrake.FunctionSpace(mesh, element, name=name) @@ -80,8 +84,6 @@ def set_function_space(dm, V): This stores the information necessary to make a function space given a DM. """ - mesh = V.mesh() - indices = [] names = [] while V.parent is not None: @@ -92,11 +94,12 @@ def set_function_space(dm, V): assert V.index is None indices.append(V.component) V = V.parent + mesh = V.mesh() if len(V) > 1: names = tuple(V_.name for V_ in V) element = V.ufl_element() boundary_sets = tuple(V_.boundary_set for V_ in V) - info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets) + info = (tuple(weakref.ref(m) for m in mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets) dm.setAttr("__fs_info__", info) @@ -414,7 +417,9 @@ def coarsen(dm, comm): """ from firedrake.mg.utils import get_level V = get_function_space(dm) - hierarchy, level = get_level(V.mesh()) + # TODO: Think harder. + m, = set(m_ for m_ in V.mesh()) + hierarchy, level = get_level(m) if level < 1: raise RuntimeError("Cannot coarsen coarsest DM") coarsen = get_ctx_coarsener(dm) diff --git a/firedrake/ensemble/ensemble_functionspace.py b/firedrake/ensemble/ensemble_functionspace.py index 7a0c7b577c..b846949258 100644 --- a/firedrake/ensemble/ensemble_functionspace.py +++ b/firedrake/ensemble/ensemble_functionspace.py @@ -92,7 +92,7 @@ class EnsembleFunctionSpaceBase: - Dual ensemble objects: :class:`EnsembleDualSpace` and :class:`~firedrake.ensemble.ensemble_function.EnsembleCofunction`. """ def __init__(self, local_spaces: Collection, ensemble: Ensemble): - meshes = set(V.mesh() for V in local_spaces) + meshes = set(V.mesh().unique() for V in local_spaces) nlocal_meshes = len(meshes) max_local_meshes = ensemble.ensemble_comm.allreduce(nlocal_meshes, MPI.MAX) if max_local_meshes > 1: diff --git a/firedrake/function.py b/firedrake/function.py index a628ac6599..b2cda5bc4e 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -593,16 +593,20 @@ def _at(self, arg, *args, **kwargs): tolerance = kwargs.get('tolerance', None) mesh = self.function_space().mesh() + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") if tolerance is None: - tolerance = mesh.tolerance + tolerance = mesh_unique.tolerance else: - mesh.tolerance = tolerance + mesh_unique.tolerance = tolerance # Handle f._at(0.3) if not arg.shape: arg = arg.reshape(-1) - if mesh.variable_layers: + if mesh_unique.variable_layers: raise NotImplementedError("Point evaluation not implemented for variable layers") # Validate geometric dimension @@ -778,7 +782,7 @@ def evaluate(self, function: Function) -> np.ndarray | Tuple[np.ndarray, ...]: if function.function_space().ufl_element().family() == "Real": return function.dat.data_ro - function_mesh = function.function_space().mesh() + function_mesh = function.function_space().mesh().unique() if function_mesh is not self.mesh: raise ValueError("Function mesh must be the same Mesh object as the PointEvaluator mesh.") if coord_changed := function_mesh.coordinates.dat.dat_version != self.mesh._saved_coordinate_dat_version: diff --git a/firedrake/functionspace.py b/firedrake/functionspace.py index cefa3bf9a4..2cca9f872b 100644 --- a/firedrake/functionspace.py +++ b/firedrake/functionspace.py @@ -4,6 +4,7 @@ API is functional, rather than object-based, to allow for simple backwards-compatibility, argument checking, and dispatch. """ +import itertools import ufl import finat.ufl @@ -253,6 +254,8 @@ def MixedFunctionSpace(spaces, name=None, mesh=None): :class:`finat.ufl.mixedelement.MixedElement`, ignored otherwise. """ + from firedrake.mesh import MeshSequenceGeometry + if isinstance(spaces, finat.ufl.FiniteElementBase): # Build the spaces if we got a mixed element assert type(spaces) is finat.ufl.MixedElement and mesh is not None @@ -267,13 +270,8 @@ def rec(eles): sub_elements.append(ele) rec(spaces.sub_elements) spaces = [FunctionSpace(mesh, element) for element in sub_elements] - - # Check that function spaces are on the same mesh - meshes = [space.mesh() for space in spaces] - for i in range(1, len(meshes)): - if meshes[i] is not meshes[0]: - raise ValueError("All function spaces must be defined on the same mesh!") - + # Flatten MeshSequences. + meshes = list(itertools.chain(*[space.mesh() for space in spaces])) try: cls, = set(type(s) for s in spaces) except ValueError: @@ -281,8 +279,6 @@ def rec(eles): # We had not implemented something in between, so let's make it primal cls = impl.WithGeometry - # Select mesh - mesh = meshes[0] # Get topological spaces spaces = tuple(s.topological for s in flatten(spaces)) # Error checking @@ -296,10 +292,9 @@ def rec(eles): else: raise ValueError("Can't make mixed space with %s" % type(space)) - new = impl.MixedFunctionSpace(spaces, name=name) - if mesh is not mesh.topology: - new = cls.create(new, mesh) - return new + mixed_mesh_geometry = MeshSequenceGeometry(meshes) + new = impl.MixedFunctionSpace(spaces, mixed_mesh_geometry.topology, name=name) + return cls.create(new, mixed_mesh_geometry) @PETSc.Log.EventDecorator("CreateFunctionSpace") diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 4dda3266e1..f5781d1d10 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -18,8 +18,8 @@ from pyop2.utils import as_tuple from firedrake import dmhooks, utils +from firedrake.mesh import MeshGeometry, MeshSequenceTopology, MeshSequenceGeometry from firedrake.functionspacedata import get_shared_data, create_element -from firedrake.mesh import MeshGeometry from firedrake.petsc import PETSc @@ -95,9 +95,11 @@ class WithGeometryBase(object): generation. """ def __init__(self, mesh, element, component=None, cargo=None): + if type(element) is finat.ufl.MixedElement: + if not isinstance(mesh, MeshSequenceGeometry): + raise TypeError(f"Can only use MixedElement with MeshSequenceGeometry: got {type(mesh)}") assert component is None or isinstance(component, int) assert cargo is None or isinstance(cargo, FunctionSpaceCargo) - super().__init__(mesh, element, label=cargo.topological._label or "") self.component = component self.cargo = cargo @@ -105,16 +107,25 @@ def __init__(self, mesh, element, component=None, cargo=None): self._comm = mpi.internal_comm(mesh.comm, self) @classmethod - def create(cls, function_space, mesh): + def create(cls, function_space, mesh, parent=None): """Create a :class:`WithGeometry`. - :arg function_space: The topological function space to attach - geometry to. - :arg mesh: The mesh with geometric information to use. + Parameters + ---------- + function_space : FunctionSpace or MixedFunctionSpace + Topological function space to attach geometry to. + mesh : MeshGeometry + Mesh with geometric information to use. + parent : WithGeometry + Parent geometric function space if exists. + """ + if isinstance(function_space, MixedFunctionSpace): + if not isinstance(mesh, MeshSequenceGeometry): + raise TypeError(f"Can only use MixedFunctionSpace with MeshSequenceGeometry: got {type(mesh)}") function_space = function_space.topological - assert mesh.topology is function_space.mesh() - assert mesh.topology is not mesh + assert mesh.topology == function_space.mesh() + assert mesh.topology != mesh element = function_space.ufl_element().reconstruct(cell=mesh.ufl_cell()) @@ -122,7 +133,8 @@ def create(cls, function_space, mesh): component = function_space.component if function_space.parent is not None: - parent = cls.create(function_space.parent, mesh) + if parent is None: + raise ValueError("Must pass parent if function_space.parent is not None") else: parent = None @@ -152,8 +164,13 @@ def topological(self, val): @utils.cached_property def subspaces(self): r"""Split into a tuple of constituent spaces.""" - return tuple(type(self).create(subspace, self.mesh()) - for subspace in self.topological.subspaces) + if isinstance(self.topological, MixedFunctionSpace): + return tuple( + type(self).create(subspace, mesh, parent=self) + for mesh, subspace in zip(self.mesh(), self.topological.subspaces, strict=True) + ) + else: + return (self, ) @property def subfunctions(self): @@ -178,11 +195,8 @@ def ufl_cell(self): @utils.cached_property def _components(self): - if len(self) == 1: - return tuple(type(self).create(self.topological.sub(i), self.mesh()) - for i in range(self.block_size)) - else: - return self.subspaces + return tuple(type(self).create(self.topological.sub(i), self.mesh(), parent=self) + for i in range(self.block_size)) @PETSc.Log.EventDecorator() def sub(self, i): @@ -301,7 +315,7 @@ def __eq__(self, other): return False try: return self.topological == other.topological and \ - self.mesh() is other.mesh() + self.mesh() == other.mesh() except AttributeError: return False @@ -363,9 +377,17 @@ def make_function_space(cls, mesh, element, name=None): topology = mesh.topology # Create a new abstract (Mixed/Real)FunctionSpace, these are neither primal nor dual. if type(element) is finat.ufl.MixedElement: - spaces = [cls.make_function_space(topology, e) for e in element.sub_elements] - new = MixedFunctionSpace(spaces, name=name) + if isinstance(mesh, MeshGeometry): + mesh = MeshSequenceGeometry([mesh for _ in element.sub_elements]) + topology = mesh.topology + else: + if not isinstance(mesh, MeshSequenceGeometry): + raise TypeError(f"mesh must be MeshSequenceGeometry: got {mesh}") + spaces = [cls.make_function_space(topo, e) for topo, e in zip(topology, element.sub_elements, strict=True)] + new = MixedFunctionSpace(spaces, topology, name=name) else: + if isinstance(mesh, MeshSequenceGeometry): + raise TypeError(f"mesh must not be MeshSequenceGeometry: got {mesh}") # Check that any Vector/Tensor/Mixed modifiers are outermost. check_element(element) if element.family() == "Real": @@ -450,7 +472,8 @@ def __init__(self, mesh, element, component=None, cargo=None): cargo=cargo) def dual(self): - return FiredrakeDualSpace.create(self.topological, self.mesh()) + parent = None if self.parent is None else self.parent.dual() + return FiredrakeDualSpace.create(self.topological, self.mesh(), parent=parent) class FiredrakeDualSpace(WithGeometryBase, ufl.functionspace.DualSpace): @@ -461,7 +484,8 @@ def __init__(self, mesh, element, component=None, cargo=None): cargo=cargo) def dual(self): - return WithGeometry.create(self.topological, self.mesh()) + parent = None if self.parent is None else self.parent.dual() + return WithGeometry.create(self.topological, self.mesh(), parent=parent) class FunctionSpace(object): @@ -591,7 +615,7 @@ def __eq__(self, other): if not isinstance(other, FunctionSpace): return False # FIXME: Think harder about equality - return self.mesh() is other.mesh() and \ + return self.mesh() == other.mesh() and \ self.dof_dset is other.dof_dset and \ self.ufl_element() == other.ufl_element() and \ self.component == other.component @@ -991,11 +1015,14 @@ class MixedFunctionSpace(object): but should instead use the functional interface provided by :func:`.MixedFunctionSpace`. """ - def __init__(self, spaces, name=None): + def __init__(self, spaces, mesh, name=None): super(MixedFunctionSpace, self).__init__() + if not isinstance(mesh, MeshSequenceTopology): + raise TypeError(f"mesh must be MeshSequenceTopology: got {mesh}") + if len(mesh) != len(spaces): + raise RuntimeError(f"len(mesh) ({len(mesh)}) != len(spaces) ({len(spaces)})") self._spaces = tuple(IndexedFunctionSpace(i, s, self) for i, s in enumerate(spaces)) - mesh, = set(s.mesh() for s in spaces) self._ufl_function_space = ufl.FunctionSpace(mesh.ufl_mesh(), finat.ufl.MixedElement(*[s.ufl_element() for s in spaces])) self.name = name or "_".join(str(s.name) for s in spaces) @@ -1199,7 +1226,9 @@ def dm(self): def _dm(self): from firedrake.mg.utils import get_level dm = self.dof_dset.dm - _, level = get_level(self.mesh()) + # TODO: Think harder. + m = self.mesh()[0] + _, level = get_level(m) dmhooks.attach_hooks(dm, level=level) return dm @@ -1376,7 +1405,7 @@ def __eq__(self, other): if not isinstance(other, RealFunctionSpace): return False # FIXME: Think harder about equality - return self.mesh() is other.mesh() and \ + return self.mesh() == other.mesh() and \ self.ufl_element() == other.ufl_element() def __ne__(self, other): diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 87757a2e54..5c9f92fb02 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -56,6 +56,7 @@ 'DEFAULT_MESH_NAME', 'MeshGeometry', 'MeshTopology', 'AbstractMeshTopology', 'ExtrudedMeshTopology', 'VertexOnlyMeshTopology', 'VertexOnlyMeshMissingPointsError', + 'MeshSequenceGeometry', 'MeshSequenceTopology', 'Submesh' ] @@ -922,6 +923,12 @@ def mark_entities(self, tf, label_value, label_name=None): def extruded_periodic(self): return self.cell_set._extruded_periodic + def __iter__(self): + yield self + + def unique(self): + return self + # submesh @utils.cached_property @@ -2834,6 +2841,12 @@ def mark_entities(self, f, label_value, label_name=None): """ self.topology.mark_entities(f.topological, label_value, label_name) + def __iter__(self): + yield self + + def unique(self): + return self + @PETSc.Log.EventDecorator() def make_mesh_from_coordinates(coordinates, name, tolerance=0.5): @@ -4695,3 +4708,185 @@ def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None): }, ) return submesh + + +class MeshSequenceGeometry(ufl.MeshSequence): + """A representation of mixed mesh geometry.""" + + def __init__(self, meshes, set_hierarchy=True): + """Initialise. + + Parameters + ---------- + meshes : tuple or list + `MeshGeometry`s to make `MeshSequenceGeometry` with. + set_hierarchy : bool + Flag for making hierarchy. + + """ + for m in meshes: + if not isinstance(m, MeshGeometry): + raise ValueError(f"Got {type(m)}") + super().__init__(meshes) + self.comm = meshes[0].comm + # Only set hierarchy at top level. + if set_hierarchy: + self.set_hierarchy() + + @utils.cached_property + def topology(self): + return MeshSequenceTopology([m.topology for m in self._meshes]) + + @property + def topological(self): + """Alias of topology. + + This is to ensure consistent naming for some multigrid codes.""" + return self.topology + + def __eq__(self, other): + if type(other) != type(self): + return False + if len(other) != len(self): + return False + for o, s in zip(other, self): + if o is not s: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self._meshes) + + def __len__(self): + return len(self._meshes) + + def __iter__(self): + return iter(self._meshes) + + def __getitem__(self, i): + return self._meshes[i] + + @utils.cached_property + def extruded(self): + m = self.unique() + return m.extruded + + def unique(self): + """Return a single component or raise exception.""" + if len(set(self._meshes)) > 1: + raise RuntimeError(f"Found multiple meshes in {self} where a single mesh is expected") + m, = set(self._meshes) + return m + + def set_hierarchy(self): + """Set mesh hierarchy if needed.""" + from firedrake.mg.utils import set_level, get_level, has_level + + # TODO: Think harder on how mesh hierarchy should work with mixed meshes. + if all(not has_level(m) for m in self._meshes): + return + else: + if not all(has_level(m) for m in self._meshes): + raise RuntimeError("Found inconsistent component meshes") + hierarchy_list = [] + level_list = [] + for m in self: + hierarchy, level = get_level(m) + hierarchy_list.append(hierarchy) + level_list.append(level) + nlevels, = set(len(hierarchy) for hierarchy in hierarchy_list) + level, = set(level_list) + result = [] + for ilevel in range(nlevels): + if ilevel == level: + result.append(self) + else: + result.append(MeshSequenceGeometry([hierarchy[ilevel] for hierarchy in hierarchy_list], set_hierarchy=False)) + result = tuple(result) + for i, m in enumerate(result): + set_level(m, result, i) + + @property + def _comm(self): + return self.topology._comm + + +class MeshSequenceTopology(object): + """A representation of mixed mesh topology.""" + + def __init__(self, meshes): + """Initialise. + + Parameters + ---------- + meshes : tuple or list + `MeshTopology`s to make `MeshSequenceTopology` with. + + """ + for m in meshes: + if not isinstance(m, AbstractMeshTopology): + raise ValueError(f"Got {type(m)}") + self._meshes = tuple(meshes) + self.comm = meshes[0].comm + self._comm = internal_comm(self.comm, self) + + @property + def topology(self): + """The underlying mesh topology object.""" + return self + + @property + def topological(self): + """Alias of topology. + + This is to ensure consistent naming for some multigrid codes.""" + return self + + def ufl_cell(self): + cell, = set(m.ufl_cell() for m in self._meshes) + return cell + + def ufl_mesh(self): + cell = self.ufl_cell() + return ufl.MeshSequence([ufl.Mesh(finat.ufl.VectorElement("Lagrange", cell, 1, dim=cell.topological_dimension())) + for _ in self._meshes]) + + def __eq__(self, other): + if type(other) != type(self): + return False + if len(other) != len(self): + return False + for o, s in zip(other, self): + if o is not s: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self._meshes) + + def __len__(self): + return len(self._meshes) + + def __iter__(self): + return iter(self._meshes) + + def __getitem__(self, i): + return self._meshes[i] + + @utils.cached_property + def extruded(self): + m = self.unique() + return m.extruded + + def unique(self): + """Return a single component or raise exception.""" + if len(set(self._meshes)) > 1: + raise RuntimeError(f"Found multiple meshes in {self} where a single mesh is expected") + m, = set(self._meshes) + return m diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index 09436bdbbd..6a648689dd 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -135,6 +135,7 @@ def compile_element(expression, dual_space=None, parameters=None, # Replace coordinates (if any) builder = firedrake_interface.KernelBuilderBase(scalar_type=ScalarType) domain = extract_unique_domain(expression) + builder._domain_integral_type_map = {domain: "cell"} # Translate to GEM cell = domain.ufl_cell() dim = cell.topological_dimension() @@ -143,7 +144,6 @@ def compile_element(expression, dual_space=None, parameters=None, config = dict(interface=builder, ufl_cell=cell, - integral_type="cell", point_indices=(), point_expr=point, argument_multiindices=argument_multiindices, @@ -519,6 +519,7 @@ def dg_injection_kernel(Vf, Vc, ncell): if complex_mode: raise NotImplementedError("In complex mode we are waiting for Slate") macro_builder = MacroKernelBuilder(ScalarType, ncell) + macro_builder._domain_integral_type_map = {Vf.mesh(): "cell"} f = ufl.Coefficient(Vf) macro_builder.set_coefficients([f]) macro_builder.set_coordinates(Vf.mesh()) @@ -536,7 +537,6 @@ def dg_injection_kernel(Vf, Vc, ncell): integration_dim, entity_ids = lower_integral_type(Vfe.cell, "cell") macro_cfg = dict(interface=macro_builder, ufl_cell=Vf.ufl_cell(), - integral_type="cell", integration_dim=integration_dim, entity_ids=entity_ids, index_cache=index_cache, @@ -557,13 +557,14 @@ def dg_injection_kernel(Vf, Vc, ncell): integral_type="cell", subdomain_id=("otherwise",), domain_number=0, + domain_integral_type_map={Vc.mesh(): "cell"}, arguments=(ufl.TestFunction(Vc), ), coefficients=(), coefficient_split={}, coefficient_numbers=()) coarse_builder = firedrake_interface.KernelBuilder(info, parameters["scalar_type"]) - coarse_builder.set_coordinates(Vc.mesh()) + coarse_builder.set_coordinates([Vc.mesh()]) argument_multiindices = coarse_builder.argument_multiindices argument_multiindex, = argument_multiindices return_variable, = coarse_builder.return_variables @@ -574,7 +575,6 @@ def dg_injection_kernel(Vf, Vc, ncell): coarse_cfg = dict(interface=coarse_builder, ufl_cell=Vc.ufl_cell(), - integral_type="cell", integration_dim=integration_dim, entity_ids=entity_ids, index_cache=index_cache, diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index 6a02b8229e..44a5e22126 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -65,6 +65,7 @@ def coarsen(expr, self, coefficient_mapping=None): @coarsen.register(ufl.Mesh) +@coarsen.register(ufl.MeshSequence) def coarsen_mesh(mesh, self, coefficient_mapping=None): hierarchy, level = utils.get_level(mesh) if hierarchy is None: @@ -148,7 +149,9 @@ def coarsen_function_space(V, self, coefficient_mapping=None): return V._coarse V_fine = V - mesh_coarse = self(V_fine.mesh(), self) + # Handle MixedFunctionSpace : V_fine.reconstruct requires MeshSequence. + fine_mesh = V_fine.mesh() if V_fine.index is None else V_fine.parent.mesh() + mesh_coarse = self(fine_mesh, self) name = f"coarse_{V.name}" if V.name else None V_coarse = V_fine.reconstruct(mesh=mesh_coarse, name=name) V_coarse._fine = V_fine diff --git a/firedrake/pointeval_utils.py b/firedrake/pointeval_utils.py index 2c14656329..a1ed8db2a9 100644 --- a/firedrake/pointeval_utils.py +++ b/firedrake/pointeval_utils.py @@ -54,6 +54,7 @@ def compile_element(expression, coordinates, parameters=None): # Initialise kernel builder builder = firedrake_interface.KernelBuilderBase(utils.ScalarType) + builder._domain_integral_type_map = {domain: "cell"} builder.domain_coordinate[domain] = coordinates builder._coefficient(coordinates, "x") x_arg = builder.generate_arg_from_expression(builder.coefficient_map[coordinates]) @@ -71,7 +72,6 @@ def compile_element(expression, coordinates, parameters=None): config = dict(interface=builder, ufl_cell=extract_unique_domain(coordinates).ufl_cell(), - integral_type="cell", point_indices=(), point_expr=point, scalar_type=utils.ScalarType) diff --git a/firedrake/pointquery_utils.py b/firedrake/pointquery_utils.py index c425356267..0bf3a36c83 100644 --- a/firedrake/pointquery_utils.py +++ b/firedrake/pointquery_utils.py @@ -143,8 +143,8 @@ def to_reference_coords_newton_step(ufl_coordinate_element, parameters, x0_dtype expr = ufl_utils.simplify_abs(expr, complex_mode) builder = firedrake_interface.KernelBuilderBase(ScalarType) + builder._domain_integral_type_map = {domain: "cell"} builder.domain_coordinate[domain] = C - Cexpr = builder._coefficient(C, "C") x0_expr = builder._coefficient(x0, "x0") loopy_args = [ @@ -162,7 +162,6 @@ def to_reference_coords_newton_step(ufl_coordinate_element, parameters, x0_dtype context = tsfc.fem.GemPointContext( interface=builder, ufl_cell=cell, - integral_type="cell", point_indices=(), point_expr=point, scalar_type=parameters["scalar_type"] diff --git a/firedrake/preconditioners/asm.py b/firedrake/preconditioners/asm.py index 7bff0e35de..8be5f07487 100644 --- a/firedrake/preconditioners/asm.py +++ b/firedrake/preconditioners/asm.py @@ -152,8 +152,12 @@ class ASMStarPC(ASMPatchPC): def get_patches(self, V): mesh = V._mesh - mesh_dm = mesh.topology_dm - if mesh.cell_set._extruded: + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") + mesh_dm = mesh_unique.topology_dm + if mesh_unique.cell_set._extruded: warning("applying ASMStarPC on an extruded mesh") # Obtain the topological entities to use to construct the stars @@ -207,8 +211,12 @@ class ASMVankaPC(ASMPatchPC): def get_patches(self, V): mesh = V._mesh - mesh_dm = mesh.topology_dm - if mesh.layers: + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") + mesh_dm = mesh_unique.topology_dm + if mesh_unique.layers: warning("applying ASMVankaPC on an extruded mesh") # Obtain the topological entities to use to construct the stars @@ -296,8 +304,12 @@ class ASMLinesmoothPC(ASMPatchPC): def get_patches(self, V): mesh = V._mesh - assert mesh.cell_set._extruded - dm = mesh.topology_dm + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") + assert mesh_unique.cell_set._extruded + dm = mesh_unique.topology_dm section = V.dm.getDefaultSection() # Obtain the codimensions to loop over from options, if present opts = PETSc.Options(self.prefix) @@ -402,9 +414,13 @@ class ASMExtrudedStarPC(ASMStarPC): def get_patches(self, V): mesh = V.mesh() - mesh_dm = mesh.topology_dm - nlayers = mesh.layers - if not mesh.cell_set._extruded: + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") + mesh_dm = mesh_unique.topology_dm + nlayers = mesh_unique.layers + if not mesh_unique.cell_set._extruded: return super(ASMExtrudedStarPC, self).get_patches(V) periodic = mesh.extruded_periodic diff --git a/firedrake/preconditioners/fdm.py b/firedrake/preconditioners/fdm.py index 24cf2c4e94..752a037bbb 100644 --- a/firedrake/preconditioners/fdm.py +++ b/firedrake/preconditioners/fdm.py @@ -222,7 +222,7 @@ def allocate_matrix(self, Amat, V, J, bcs, fcp, pmat_type, use_static_condensati elif len(ifacet) == 1: Vfacet = V[ifacet[0]] ebig, = set(unrestrict_element(Vsub.ufl_element()) for Vsub in V) - Vbig = V.reconstruct(element=ebig) + Vbig = V.reconstruct(mesh=V.mesh().unique(), element=ebig) if len(V) > 1: dims = [Vsub.finat_element.space_dimension() for Vsub in V] assert sum(dims) == Vbig.finat_element.space_dimension() diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index c139c07f7a..bb47093a3d 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -5,6 +5,7 @@ from firedrake.utils import cached_property, complex_mode, IntType from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx from firedrake.interpolation import Interpolate +from firedrake.ufl_expr import extract_domains from collections import namedtuple import operator @@ -135,6 +136,10 @@ def increment_dat_version(self): CompiledKernel = namedtuple('CompiledKernel', ["funptr", "kinfo"]) +def get_map(V, base_mesh, base_integral_type): + return V.topological.entity_node_map(base_mesh.topology, base_integral_type, None, None) + + def matrix_funptr(form, state): from firedrake.tsfc_interface import compile_form test, trial = map(operator.methodcaller("function_space"), form.arguments()) @@ -148,10 +153,13 @@ def matrix_funptr(form, state): kernels = compile_form(form, "subspace_form", split=False, dont_split=dont_split) + all_meshes = extract_domains(form) cell_kernels = [] int_facet_kernels = [] for kernel in kernels: kinfo = kernel.kinfo + mesh = all_meshes[kinfo.domain_number] # integration domain + integral_type = kinfo.integral_type if kinfo.subdomain_id != ("otherwise",): raise NotImplementedError("Only for full domain integrals") @@ -161,21 +169,17 @@ def matrix_funptr(form, state): # OK, now we've validated the kernel, let's build the callback args = [] - if kinfo.integral_type == "cell": - get_map = operator.methodcaller("cell_node_map") + if integral_type == "cell": kernels = cell_kernels - elif kinfo.integral_type == "interior_facet": - get_map = operator.methodcaller("interior_facet_node_map") + elif integral_type == "interior_facet": kernels = int_facet_kernels - else: - get_map = None toset = op2.Set(1, comm=test.comm) dofset = op2.DataSet(toset, 1) arity = sum(m.arity*s.cdim - for m, s in zip(get_map(test), + for m, s in zip(get_map(test, mesh, integral_type), test.dof_dset)) - iterset = get_map(test).iterset + iterset = get_map(test, mesh, integral_type).iterset entity_node_map = op2.Map(iterset, toset, arity, values=numpy.zeros(iterset.total_size*arity, dtype=IntType)) @@ -189,16 +193,17 @@ def matrix_funptr(form, state): values=numpy.zeros(iterset.total_size*arity, dtype=IntType)) statearg = statedat(op2.READ, state_entity_node_map) - mesh = form.ufl_domains()[kinfo.domain_number] - arg = mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)) - args.append(arg) - if kinfo.oriented: - c = mesh.cell_orientations() - arg = c.dat(op2.READ, get_map(c)) + for i in kinfo.active_domain_numbers.coordinates: + c = all_meshes[i].coordinates + arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type)) args.append(arg) - if kinfo.needs_cell_sizes: - c = mesh.cell_sizes - arg = c.dat(op2.READ, get_map(c)) + for i in kinfo.active_domain_numbers.cell_orientations: + c = all_meshes[i].cell_orientations() + arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type)) + args.append(arg) + for i in kinfo.active_domain_numbers.cell_sizes: + c = all_meshes[i].cell_sizes + arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type)) args.append(arg) for n, indices in kinfo.coefficient_numbers: c = form.coefficients()[n] @@ -209,7 +214,7 @@ def matrix_funptr(form, state): continue for ind in indices: c_ = c.subfunctions[ind] - map_ = get_map(c_) + map_ = get_map(c_.function_space(), mesh, integral_type) arg = c_.dat(op2.READ, map_) args.append(arg) @@ -217,7 +222,7 @@ def matrix_funptr(form, state): for constant_index in kinfo.constant_numbers: args.append(all_constants[constant_index].dat(op2.READ)) - if kinfo.integral_type == "interior_facet": + if integral_type == "interior_facet": arg = mesh.interior_facets.local_facet_dat(op2.READ) args.append(arg) iterset = op2.Subset(iterset, []) @@ -242,10 +247,13 @@ def residual_funptr(form, state): kernels = compile_form(form, "subspace_form", split=False, dont_split=dont_split) + all_meshes = extract_domains(form) cell_kernels = [] int_facet_kernels = [] for kernel in kernels: kinfo = kernel.kinfo + mesh = all_meshes[kinfo.domain_number] # integration domain + integral_type = kinfo.integral_type if kinfo.subdomain_id != ("otherwise",): raise NotImplementedError("Only for full domain integrals") @@ -254,20 +262,16 @@ def residual_funptr(form, state): args = [] if kinfo.integral_type == "cell": - get_map = operator.methodcaller("cell_node_map") kernels = cell_kernels elif kinfo.integral_type == "interior_facet": - get_map = operator.methodcaller("interior_facet_node_map") kernels = int_facet_kernels - else: - get_map = None toset = op2.Set(1, comm=test.comm) dofset = op2.DataSet(toset, 1) arity = sum(m.arity*s.cdim - for m, s in zip(get_map(test), + for m, s in zip(get_map(test, mesh, integral_type), test.dof_dset)) - iterset = get_map(test).iterset + iterset = get_map(test, mesh, integral_type).iterset entity_node_map = op2.Map(iterset, toset, arity, values=numpy.zeros(iterset.total_size*arity, dtype=IntType)) @@ -282,17 +286,17 @@ def residual_funptr(form, state): arg = dat(op2.INC, entity_node_map) args.append(arg) - mesh = form.ufl_domains()[kinfo.domain_number] - arg = mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)) - args.append(arg) - - if kinfo.oriented: - c = mesh.cell_orientations() - arg = c.dat(op2.READ, get_map(c)) + for i in kinfo.active_domain_numbers.coordinates: + c = all_meshes[i].coordinates + arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type)) + args.append(arg) + for i in kinfo.active_domain_numbers.cell_orientations: + c = all_meshes[i].cell_orientations() + arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type)) args.append(arg) - if kinfo.needs_cell_sizes: - c = mesh.cell_sizes - arg = c.dat(op2.READ, get_map(c)) + for i in kinfo.active_domain_numbers.cell_sizes: + c = all_meshes[i].cell_sizes + arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type)) args.append(arg) for n, indices in kinfo.coefficient_numbers: c = form.coefficients()[n] @@ -303,7 +307,7 @@ def residual_funptr(form, state): continue for ind in indices: c_ = c.subfunctions[ind] - map_ = get_map(c_) + map_ = get_map(c_.function_space(), mesh, integral_type) arg = c_.dat(op2.READ, map_) args.append(arg) @@ -516,14 +520,14 @@ def load_c_function(code, name, comm): return fn -def make_c_arguments(form, kernel, state, get_map, require_state=False, +def make_c_arguments(form, kernel, state, integral_type, require_state=False, require_facet_number=False): - mesh = form.ufl_domains()[kernel.kinfo.domain_number] - coeffs = [mesh.coordinates] - if kernel.kinfo.oriented: - coeffs.append(mesh.cell_orientations()) - if kernel.kinfo.needs_cell_sizes: - coeffs.append(mesh.cell_sizes) + all_meshes = extract_domains(form) + mesh = all_meshes[kernel.kinfo.domain_number] + coeffs = [] + coeffs.extend([all_meshes[i].coordinates for i in kernel.kinfo.active_domain_numbers.coordinates]) + coeffs.extend([all_meshes[i].cell_orientations() for i in kernel.kinfo.active_domain_numbers.cell_orientations]) + coeffs.extend([all_meshes[i].cell_sizes for i in kernel.kinfo.active_domain_numbers.cell_sizes]) for n, indices in kernel.kinfo.coefficient_numbers: c = form.coefficients()[n] if c is state: @@ -543,7 +547,7 @@ def make_c_arguments(form, kernel, state, get_map, require_state=False, map_args.append(None) else: data_args.extend(c.dat._kernel_args_) - map_ = get_map(c) + map_ = get_map(c.function_space(), mesh, integral_type) if map_ is not None: for k in map_._kernel_args_: if k not in seen: @@ -652,7 +656,11 @@ def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None): raise RuntimeError("Must either set ndiv or divisions for PlaneSmoother!") mesh = dm.getAttr("__firedrake_mesh__") - coordinates = mesh.coordinates + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") + coordinates = mesh_unique.coordinates V = coordinates.function_space() if V.finat_element.is_dg(): # We're using DG or DQ for our coordinates, so we got @@ -760,7 +768,11 @@ def initialize(self, obj): J, bcs = self.form(obj) V = J.arguments()[0].function_space() mesh = V.mesh() - self.plex = mesh.topology_dm + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") + self.plex = mesh_unique.topology_dm # We need to attach the mesh and appctx to the plex, so that # PlaneSmoothers (and any other user-customised patch # constructors) can use firedrake's opinion of what @@ -769,10 +781,10 @@ def initialize(self, obj): self.ctx = ctx self.plex.setAttr("__firedrake_ctx__", weakref.proxy(ctx)) - if mesh.cell_set._extruded: + if mesh_unique.cell_set._extruded: raise NotImplementedError("Not implemented on extruded meshes") - if "overlap_type" not in mesh._distribution_parameters: + if "overlap_type" not in mesh_unique._distribution_parameters: if mesh.comm.size > 1: # Want to do # warnings.warn("You almost surely want to set an overlap_type in your mesh's distribution_parameters.") @@ -806,8 +818,7 @@ def initialize(self, obj): Jcell_kernels, Jint_facet_kernels = matrix_funptr(J, Jstate) Jcell_kernel, = Jcell_kernels Jcell_flops = Jcell_kernel.kinfo.kernel.num_flops - Jop_data_args, Jop_map_args = make_c_arguments(J, Jcell_kernel, Jstate, - operator.methodcaller("cell_node_map")) + Jop_data_args, Jop_map_args = make_c_arguments(J, Jcell_kernel, Jstate, "cell") code, Struct = make_jacobian_wrapper(Jop_data_args, Jop_map_args, Jcell_flops) Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm) Jop_struct = make_c_struct(Jop_data_args, Jop_map_args, Jcell_kernel.funptr, Struct) @@ -818,11 +829,11 @@ def initialize(self, obj): Jhas_int_facet_kernel = True Jint_facet_flops = Jint_facet_kernel.kinfo.kernel.num_flops facet_Jop_data_args, facet_Jop_map_args = make_c_arguments(J, Jint_facet_kernel, Jstate, - operator.methodcaller("interior_facet_node_map"), + "interior_facet", require_facet_number=True) code, Struct = make_jacobian_wrapper(facet_Jop_data_args, facet_Jop_map_args, Jint_facet_flops) facet_Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm) - point2facet = mesh.interior_facets.point2facetnumber.ctypes.data + point2facet = mesh_unique.interior_facets.point2facetnumber.ctypes.data facet_Jop_struct = make_c_struct(facet_Jop_data_args, facet_Jop_map_args, Jint_facet_kernel.funptr, Struct, point2facet=point2facet) @@ -836,7 +847,7 @@ def initialize(self, obj): Fcell_kernel, = Fcell_kernels Fcell_flops = Fcell_kernel.kinfo.kernel.num_flops Fop_data_args, Fop_map_args = make_c_arguments(F, Fcell_kernel, Fstate, - operator.methodcaller("cell_node_map"), + "cell", require_state=True) code, Struct = make_residual_wrapper(Fop_data_args, Fop_map_args, Fcell_flops) Fop_function = load_c_function(code, "ComputeResidual", mesh.comm) @@ -848,7 +859,7 @@ def initialize(self, obj): Fhas_int_facet_kernel = True Fint_facet_flops = Fint_facet_kernel.kinfo.kernel.num_flops facet_Fop_data_args, facet_Fop_map_args = make_c_arguments(F, Fint_facet_kernel, Fstate, - operator.methodcaller("interior_facet_node_map"), + "interior_facet", require_state=True, require_facet_number=True) code, Struct = make_jacobian_wrapper(facet_Fop_data_args, facet_Fop_map_args, Fint_facet_flops) @@ -859,7 +870,7 @@ def initialize(self, obj): point2facet=point2facet) patch.setDM(self.plex) - patch.setPatchCellNumbering(mesh._cell_numbering) + patch.setPatchCellNumbering(mesh_unique._cell_numbering) offsets = numpy.append([0], numpy.cumsum([W.dof_count for W in V])).astype(PETSc.IntType) diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index f4b45a67a5..5ed6d7b87b 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -1229,7 +1229,7 @@ def _weight(self): }} """ kernel = op2.Kernel(kernel_code, "weight", requires_zeroed_output_arguments=True) - op2.par_loop(kernel, weight.cell_set, weight.dat(op2.INC, weight.cell_node_map())) + op2.par_loop(kernel, weight.function_space().mesh().topology.unique().cell_set, weight.dat(op2.INC, weight.cell_node_map())) with weight.dat.vec as w: w.reciprocal() return weight @@ -1243,7 +1243,7 @@ def _kernels(self): uf_map = get_permuted_map(self.Vf) uc_map = get_permuted_map(self.Vc) prolong_kernel, restrict_kernel, coefficients = self.make_blas_kernels(self.Vf, self.Vc) - prolong_args = [prolong_kernel, self.uf.cell_set, + prolong_args = [prolong_kernel, self.uf.function_space().mesh().topology.unique().cell_set, self.uf.dat(op2.INC, uf_map), self.uc.dat(op2.READ, uc_map), self._weight.dat(op2.READ, uf_map)] @@ -1253,11 +1253,11 @@ def _kernels(self): uf_map = self.Vf.cell_node_map() uc_map = self.Vc.cell_node_map() prolong_kernel, restrict_kernel, coefficients = self.make_kernels(self.Vf, self.Vc) - prolong_args = [prolong_kernel, self.uf.cell_set, + prolong_args = [prolong_kernel, self.uf.function_space().mesh().topology.unique().cell_set, self.uf.dat(op2.WRITE, uf_map), self.uc.dat(op2.READ, uc_map)] - restrict_args = [restrict_kernel, self.uf.cell_set, + restrict_args = [restrict_kernel, self.uf.function_space().mesh().topology.unique().cell_set, self.uc.dat(op2.INC, uc_map), self.uf.dat(op2.READ, uf_map), self._weight.dat(op2.READ, uf_map)] @@ -1595,7 +1595,7 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]): lgmaps=((rlgmap, clgmap), ), unroll_map=unroll) expr = firedrake.TrialFunction(P1.sub(i)) kernel, coefficients = prolongation_transfer_kernel_action(Pk.sub(i), expr) - parloop_args = [kernel, mesh.cell_set, matarg] + parloop_args = [kernel, mesh.topology.unique().cell_set, matarg] for coefficient in coefficients: m_ = coefficient.cell_node_map() parloop_args.append(coefficient.dat(op2.READ, m_)) @@ -1612,7 +1612,7 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]): lgmaps=((rlgmap, clgmap), ), unroll_map=unroll) expr = firedrake.TrialFunction(P1) kernel, coefficients = prolongation_transfer_kernel_action(Pk, expr) - parloop_args = [kernel, mesh.cell_set, matarg] + parloop_args = [kernel, mesh.topology.unique().cell_set, matarg] for coefficient in coefficients: m_ = coefficient.cell_node_map() parloop_args.append(coefficient.dat(op2.READ, m_)) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 84c1cddc08..8751348b16 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -36,6 +36,7 @@ from gem import indices as make_indices from tsfc.kernel_args import OutputKernelArg, CoefficientKernelArg from tsfc.loopy import generate as generate_loopy +from tsfc.kernel_interface.firedrake_loopy import ActiveDomainNumbers import copy from petsc4py import PETSc @@ -192,14 +193,19 @@ def generate_loopy_kernel(slate_expr, compiler_parameters=None): kinfo = KernelInfo(kernel=loopykernel, integral_type="cell", # slate can only do things as contributions to the cell integrals - oriented=builder.bag.needs_cell_orientations, subdomain_id=("otherwise",), domain_number=0, + active_domain_numbers=ActiveDomainNumbers(coordinates=(0, ) if builder.bag.needs_coordinates else (), + cell_orientations=(0, ) if builder.bag.needs_cell_orientations else (), + cell_sizes=(0, ) if builder.bag.needs_cell_sizes else (), + exterior_facets=(), + interior_facets=(), + orientations_exterior_facet=(), + orientations_interior_facet=(),), coefficient_numbers=coefficient_numbers, constant_numbers=constant_numbers, needs_cell_facets=builder.bag.needs_cell_facets, pass_layer_arg=builder.bag.needs_mesh_layers, - needs_cell_sizes=builder.bag.needs_cell_sizes, arguments=arguments, events=events) diff --git a/firedrake/slate/slac/kernel_builder.py b/firedrake/slate/slac/kernel_builder.py index 8cf27b5298..1a9e9172d9 100644 --- a/firedrake/slate/slac/kernel_builder.py +++ b/firedrake/slate/slac/kernel_builder.py @@ -132,13 +132,22 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap that are coordinates, orientations, cell sizes and cofficients. """ - kernel_data = [(mesh.coordinates, self.coordinates_arg_name)] - - if kinfo.oriented: + kernel_data = [] + for coord_domain_number in kinfo.active_domain_numbers.coordinates: + if coord_domain_number != 0: + raise ValueError("Slate currently only supports single domain") + self.bag.needs_coordinates = True + kernel_data.append((mesh.coordinates, self.coordinates_arg_name)) + + for cell_orientation_domain_number in kinfo.active_domain_numbers.cell_orientations: + if cell_orientation_domain_number != 0: + raise ValueError("Slate currently only supports single domain") self.bag.needs_cell_orientations = True kernel_data.append((mesh.cell_orientations(), self.cell_orientations_arg_name)) - if kinfo.needs_cell_sizes: + for cell_size_domain_number in kinfo.active_domain_numbers.cell_sizes: + if cell_size_domain_number != 0: + raise ValueError("Slate currently only supports single domain") self.bag.needs_cell_sizes = True kernel_data.append((mesh.cell_sizes, self.cell_sizes_arg_name)) @@ -337,10 +346,11 @@ def generate_wrapper_kernel_args(self, tensor2temp): args = [] tmp_args = [] - coords_extent = self.extent(self.expression.ufl_domain().coordinates) - coords_loopy_arg = loopy.GlobalArg(self.coordinates_arg_name, shape=coords_extent, - dtype=self.tsfc_parameters["scalar_type"]) - args.append(kernel_args.CoordinatesKernelArg(coords_loopy_arg)) + if self.bag.needs_coordinates: + coords_extent = self.extent(self.expression.ufl_domain().coordinates) + coords_loopy_arg = loopy.GlobalArg(self.coordinates_arg_name, shape=coords_extent, + dtype=self.tsfc_parameters["scalar_type"]) + args.append(kernel_args.CoordinatesKernelArg(coords_loopy_arg)) if self.bag.needs_cell_orientations: ori_extent = self.extent(self.expression.ufl_domain().cell_orientations()) @@ -440,9 +450,15 @@ def generate_tsfc_calls(self, terminal, loopy_tensor): if subdomain_id != "otherwise": raise NotImplementedError("No subdomain markers for cells yet") elif self.is_integral_type(integral_type, "facet_integral"): - predicates, fidx, facet_arg = self.facet_integral_predicates(mesh, integral_type, kinfo, subdomain_id) - reads.append(facet_arg) - inames_dep.append(fidx[0].name) + if kinfo.active_domain_numbers._asdict()[{"exterior_facet": "exterior_facets", + "exterior_facet_vert": "exterior_facets", + "interior_facet": "interior_facets", + "interior_facet_vert": "interior_facets"}[kinfo.integral_type]] != (): + predicates, fidx, facet_arg = self.facet_integral_predicates(mesh, integral_type, kinfo, subdomain_id) + reads.append(facet_arg) + inames_dep.append(fidx[0].name) + else: + predicates = None elif self.is_integral_type(integral_type, "layer_integral"): predicates = self.layer_integral_predicates(slate_tensor, integral_type) else: @@ -469,6 +485,7 @@ def __init__(self, coeffs, constants): self.coefficients = coeffs self.constants = constants self.inames = OrderedDict() + self.needs_coordinates = False self.needs_cell_orientations = False self.needs_cell_sizes = False self.needs_cell_facets = False diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 942faf2bd8..aa3b7b4b8c 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -15,7 +15,7 @@ functions to be executed within the Firedrake architecture. """ from abc import ABCMeta, abstractproperty, abstractmethod - +import functools from collections import OrderedDict, namedtuple, defaultdict from ufl import Constant @@ -253,6 +253,13 @@ def ufl_domain(self): raise ValueError("All integrals must share the same domain of integration.") return domain + @staticmethod + def _expand_mixed_meshes(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + return sort_domains(join_domains(func(self, *args, **kwargs))) + return wrapper + @abstractmethod def ufl_domains(self): """Returns the integration domains of the integrals associated with @@ -487,6 +494,7 @@ def slate_coefficients(self): """Returns a tuple of coefficients associated with the tensor.""" return self.coefficients() + @TensorBase._expand_mixed_meshes def ufl_domains(self): """Returns the integration domains of the integrals associated with the tensor. @@ -562,6 +570,7 @@ def slate_coefficients(self): """Returns a BlockFunction in a tuple which carries all information to generate the right coefficients and maps.""" return (BlockFunction(self._function, self._indices, self._original_function),) + @TensorBase._expand_mixed_meshes def ufl_domains(self): """Returns the integration domains of the integrals associated with the tensor. """ @@ -720,6 +729,7 @@ def slate_coefficients(self): """Returns a tuple of coefficients associated with the tensor.""" return self.coefficients() + @TensorBase._expand_mixed_meshes def ufl_domains(self): """Returns the integration domains of the integrals associated with the tensor. @@ -815,6 +825,7 @@ def slate_coefficients(self): """Returns a tuple of coefficients associated with the tensor.""" return self.coefficients() + @TensorBase._expand_mixed_meshes def ufl_domains(self): """Returns the integration domains of the integrals associated with the tensor. @@ -918,6 +929,7 @@ def slate_coefficients(self): """Returns a tuple of coefficients associated with the tensor.""" return self.coefficients() + @TensorBase._expand_mixed_meshes def ufl_domains(self): """Returns the integration domains of the integrals associated with the tensor. @@ -976,6 +988,7 @@ def slate_coefficients(self): coeffs = [op.slate_coefficients() for op in self.operands] return tuple(OrderedDict.fromkeys(chain(*coeffs))) + @TensorBase._expand_mixed_meshes def ufl_domains(self): """Returns the integration domains of the integrals associated with the tensor. diff --git a/firedrake/slate/static_condensation/hybridization.py b/firedrake/slate/static_condensation/hybridization.py index bd0fc761f0..cfc72f2119 100644 --- a/firedrake/slate/static_condensation/hybridization.py +++ b/firedrake/slate/static_condensation/hybridization.py @@ -54,6 +54,10 @@ def initialize(self, pc): V = test.function_space() mesh = V.mesh() + if len(set(mesh)) == 1: + mesh_unique = mesh.unique() + else: + raise NotImplementedError("Not implemented for general mixed meshes") if len(V) != 2: raise ValueError("Expecting two function spaces.") @@ -83,7 +87,7 @@ def initialize(self, pc): except TypeError: tdegree = W.ufl_element().degree() - 1 - TraceSpace = FunctionSpace(mesh, "HDiv Trace", tdegree) + TraceSpace = FunctionSpace(mesh[self.vidx], "HDiv Trace", tdegree) # Break the function spaces and define fully discontinuous spaces broken_elements = finat.ufl.MixedElement([finat.ufl.BrokenElement(Vi.ufl_element()) for Vi in V]) @@ -122,10 +126,10 @@ def initialize(self, pc): trial: TrialFunction(V_d)} Atilde = Tensor(replace(self.ctx.a, arg_map)) gammar = TestFunction(TraceSpace) - n = ufl.FacetNormal(mesh) + n = ufl.FacetNormal(mesh_unique) sigma = TrialFunctions(V_d)[self.vidx] - if mesh.cell_set._extruded: + if mesh_unique.cell_set._extruded: Kform = (gammar('+') * ufl.jump(sigma, n=n) * ufl.dS_h + gammar('+') * ufl.jump(sigma, n=n) * ufl.dS_v) else: @@ -159,7 +163,7 @@ def initialize(self, pc): integrand = gammar * ufl.dot(sigma, n) measures = [] trace_subdomains = [] - if mesh.cell_set._extruded: + if mesh_unique.cell_set._extruded: ds = ufl.ds_v for subdomain in sorted(extruded_neumann_subdomains): measures.append({"top": ufl.ds_t, "bottom": ufl.ds_b}[subdomain]) @@ -170,7 +174,7 @@ def initialize(self, pc): measures.append(ds) else: measures.extend((ds(sd) for sd in sorted(neumann_subdomains))) - markers = [int(x) for x in mesh.exterior_facets.unique_markers] + markers = [int(x) for x in mesh_unique.exterior_facets.unique_markers] dirichlet_subdomains = set(markers) - neumann_subdomains trace_subdomains.extend(sorted(dirichlet_subdomains)) @@ -185,7 +189,7 @@ def initialize(self, pc): # the exterior boundary. Extruded cells will have both # horizontal and vertical facets trace_subdomains = ["on_boundary"] - if mesh.cell_set._extruded: + if mesh_unique.cell_set._extruded: trace_subdomains.extend(["bottom", "top"]) trace_bcs = [DirichletBC(TraceSpace, 0, subdomain) for subdomain in trace_subdomains] diff --git a/firedrake/slate/static_condensation/scpc.py b/firedrake/slate/static_condensation/scpc.py index e9ccb022b3..22de4891b7 100644 --- a/firedrake/slate/static_condensation/scpc.py +++ b/firedrake/slate/static_condensation/scpc.py @@ -62,7 +62,7 @@ def initialize(self, pc): # Need to duplicate a space which is NOT # associated with a subspace of a mixed space. - Vc = FunctionSpace(W.mesh(), W[c_field].ufl_element()) + Vc = FunctionSpace(W.mesh()[c_field], W[c_field].ufl_element()) bcs = [] cxt_bcs = self.cxt.row_bcs for bc in cxt_bcs: diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index e7d5904d1c..cde7f678a1 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -11,11 +11,12 @@ import ufl import finat.ufl from ufl import conj, Form, ZeroBaseForm -from .ufl_expr import TestFunction +from .ufl_expr import TestFunction, extract_domains from tsfc import compile_form as original_tsfc_compile_form from tsfc.parameters import PARAMETERS as tsfc_default_parameters from tsfc.ufl_utils import extract_firedrake_constants +from tsfc.kernel_interface.firedrake_loopy import ActiveDomainNumbers from pyop2 import op2 from pyop2.caching import memory_and_disk_cache, default_parallel_hashkey @@ -34,14 +35,13 @@ KernelInfo = collections.namedtuple("KernelInfo", ["kernel", "integral_type", - "oriented", "subdomain_id", "domain_number", + "active_domain_numbers", "coefficient_numbers", "constant_numbers", "needs_cell_facets", "pass_layer_arg", - "needs_cell_sizes", "arguments", "events"]) @@ -80,6 +80,7 @@ def __init__( form, name, parameters, + domain_number_map, coefficient_numbers, constant_numbers, dont_split_numbers, @@ -95,6 +96,8 @@ def __init__( A prefix to be applied to the compiled kernel names. This is primarily useful for debugging. parameters : dict A dict of parameters to pass to the form compiler. + domain_number_map : dict + Map from domain numbers in the provided (split) form to domain numbers in the original form. coefficient_numbers : dict Map from coefficient numbers in the provided (split) form to coefficient numbers in the original form. constant_numbers : dict @@ -110,6 +113,8 @@ def __init__( diagonal=diagonal) kernels = [] for kernel in tree: + domain_number = domain_number_map[kernel.domain_number] + active_domain_numbers = ActiveDomainNumbers(*(tuple(domain_number_map[dn] for dn in dn_tuple) for dn_tuple in kernel.active_domain_numbers)) # Individual kernels do not have to use all of the coefficients # provided by the (split) form. Here we combine the numberings # of (kernel coefficients -> split form coefficients) and @@ -130,14 +135,13 @@ def __init__( events=events) kernels.append(KernelInfo(kernel=pyop2_kernel, integral_type=kernel.integral_type, - oriented=kernel.oriented, subdomain_id=kernel.subdomain_id, - domain_number=kernel.domain_number, + domain_number=domain_number, + active_domain_numbers=active_domain_numbers, coefficient_numbers=coefficient_numbers_per_kernel, constant_numbers=constant_numbers_per_kernel, needs_cell_facets=False, pass_layer_arg=False, - needs_cell_sizes=kernel.needs_cell_sizes, arguments=kernel.arguments, events=events)) self.kernels = tuple(kernels) @@ -216,6 +220,7 @@ def compile_form(form, name, parameters=None, split=True, dont_split=(), diagona kernels = [] numbering = form.terminal_numbering() + all_meshes = extract_domains(form) if split: iterable = split_form(form, diagonal=diagonal) else: @@ -231,8 +236,10 @@ def compile_form(form, name, parameters=None, split=True, dont_split=(), diagona # and that component doesn't actually appear in the form then we # have an empty form, which we should not attempt to assemble. continue - # Map local coefficient/constant numbers (as seen inside the + # Map local domain/coefficient/constant numbers (as seen inside the # compiler) to the global coefficient/constant numbers + meshes = extract_domains(f) + domain_number_map = tuple(all_meshes.index(m) for m in meshes) coefficient_numbers = tuple( numbering[c] for c in f.coefficients() ) @@ -245,6 +252,7 @@ def compile_form(form, name, parameters=None, split=True, dont_split=(), diagona f, prefix, parameters, + domain_number_map, coefficient_numbers, constant_numbers, dont_split_numbers, @@ -291,20 +299,21 @@ def _ensure_cachedir(comm=None): def gather_integer_subdomain_ids(knls): - """Gather a dict of all integer subdomain IDs per integral type. + """Gather a dict of all integer subdomain IDs per integral type per domain. This is needed to correctly interpret the ``"otherwise"`` subdomain ID. :arg knls: Iterable of :class:`SplitKernel` objects. """ - all_integer_subdomain_ids = collections.defaultdict(list) + all_integer_subdomain_ids = collections.defaultdict(lambda: collections.defaultdict(set)) for _, kinfo in knls: for subdomain_id in kinfo.subdomain_id: if subdomain_id != "otherwise": - all_integer_subdomain_ids[kinfo.integral_type].append(subdomain_id) + all_integer_subdomain_ids[kinfo.domain_number][kinfo.integral_type].add(subdomain_id) - for k, v in all_integer_subdomain_ids.items(): - all_integer_subdomain_ids[k] = tuple(sorted(v)) + for domain_number, integral_type_subdomain_ids_dict in all_integer_subdomain_ids.items(): + for integral_type, subdomain_ids in integral_type_subdomain_ids_dict.items(): + all_integer_subdomain_ids[domain_number][integral_type] = tuple(sorted(subdomain_ids)) return all_integer_subdomain_ids diff --git a/firedrake/ufl_expr.py b/firedrake/ufl_expr.py index e4fc89b175..fabe221ee5 100644 --- a/firedrake/ufl_expr.py +++ b/firedrake/ufl_expr.py @@ -5,7 +5,6 @@ from ufl.split_functions import split from ufl.algorithms import extract_arguments, extract_coefficients from ufl.domain import as_domain - import firedrake from firedrake import utils, function, cofunction from firedrake.constant import Constant @@ -233,63 +232,64 @@ def derivative(form, u, du=None, coefficient_derivatives=None): raise TypeError( f"Cannot take the derivative of a {type(form).__name__}" ) - u_is_x = isinstance(u, ufl.SpatialCoordinate) - if u_is_x or isinstance(u, (Constant, BaseFormOperator)): - uc = u - else: - uc, = extract_coefficients(u) - if not (u_is_x or isinstance(u, BaseFormOperator)) and len(uc.subfunctions) > 1 and set(extract_coefficients(form)) & set(uc.subfunctions): - raise ValueError("Taking derivative of form wrt u, but form contains coefficients from u.subfunctions." - "\nYou probably meant to write split(u) when defining your form.") - - mesh = as_domain(form) - if not mesh: - raise ValueError("Expression to be differentiated has no ufl domain." - "\nDo you need to add a domain to your Constant?") - is_dX = u_is_x or u is mesh.coordinates - try: args = form.arguments() except AttributeError: args = extract_arguments(form) # UFL arguments need unique indices within a form n = max(a.number() for a in args) if args else -1 - - if is_dX: - coords = mesh.coordinates - u = ufl.SpatialCoordinate(mesh) + set_internal_coord_derivatives = False + all_meshes = extract_domains(form) + if isinstance(u, ufl.SpatialCoordinate): + uc = u + coords_mesh, = extract_unique_domain(u) + coords = coords_mesh.coordinates + V = coords.function_space() + set_internal_coord_derivatives = True + elif any(u is m.coordinates for m in all_meshes): + uc = u + coords = u + coord_mesh = u.function_space().mesh() + u = ufl.SpatialCoordinate(coord_mesh) V = coords.function_space() - elif isinstance(uc, (firedrake.Function, firedrake.Cofunction, BaseFormOperator)): + set_internal_coord_derivatives = True + elif isinstance(u, BaseFormOperator): + uc = u V = uc.function_space() - elif isinstance(uc, firedrake.Constant): + elif isinstance(u, Constant): + uc = u if uc.ufl_shape != (): raise ValueError("Real function space of vector elements not supported") # Replace instances of the constant with a new argument ``x`` # and differentiate wrt ``x``. + mesh = as_domain(form) # integration domain V = firedrake.FunctionSpace(mesh, "Real", 0) x = ufl.Coefficient(V) # TODO: Update this line when https://github.com/FEniCS/ufl/issues/171 is fixed form = ufl.replace(form, {u: x}) u_orig, u = u, x else: - raise RuntimeError("Can't compute derivative for form") - + uc, = extract_coefficients(u) + if not isinstance(uc, (firedrake.Function, firedrake.Cofunction)): + raise RuntimeError(f"Can't compute derivative for form w.r.t {u}") + if len(uc.subfunctions) > 1 and set(extract_coefficients(form)) & set(uc.subfunctions): + raise ValueError("Taking derivative of form wrt u, but form contains coefficients from u.subfunctions." + "\nYou probably meant to write split(u) when defining your form.") + V = uc.function_space() if du is None: du = Argument(V, n + 1) - - if is_dX: + if set_internal_coord_derivatives: internal_coefficient_derivatives = {coords: du} else: internal_coefficient_derivatives = {} if coefficient_derivatives: internal_coefficient_derivatives.update(coefficient_derivatives) - if u.ufl_shape != du.ufl_shape: raise ValueError("Shapes of u and du do not match.\n" "If you passed an indexed part of split(u) into " "derivative, you need to provide an appropriate du as well.") dform = ufl.derivative(form, u, du, internal_coefficient_derivatives) - if isinstance(uc, firedrake.Constant): + if isinstance(uc, Constant): # If we replaced constants with ``x`` to differentiate, # replace them back to the original symbolic constant dform = ufl.replace(dform, {u: u_orig}) @@ -366,23 +366,37 @@ def FacetNormal(mesh): return ufl.FacetNormal(mesh) -def extract_domains(func): - """Extract the domain from `func`. +def extract_domains(f): + """Extract the domain from `f`. Parameters ---------- - x : firedrake.function.Function, firedrake.cofunction.Cofunction, or firedrake.constant.Constant - The function to extract the domain from. + f : ufl.form.Form or firedrake.slate.TensorBase or firedrake.function.Function or firedrake.cofunction.Cofunction or firedrake.constant.Constant + The form, tensor, or function to extract the domain from. Returns ------- list of firedrake.mesh.MeshGeometry Extracted domains. """ - if isinstance(func, (function.Function, cofunction.Cofunction, Argument, Coargument)): - return [func.function_space().mesh()] + from firedrake.mesh import MeshSequenceGeometry + + if isinstance(f, firedrake.slate.TensorBase): + return f.ufl_domains() + elif isinstance(f, (cofunction.Cofunction, Coargument)): + # ufl.domain.extract_domains does not work. + mesh = f.function_space().mesh() + if isinstance(mesh, MeshSequenceGeometry): + return list(set(mesh._meshes)) + else: + return [mesh] + elif isinstance(f, (ufl.form.FormSum, ufl.Action)): + # ufl.domain.extract_domains does not work. + if f._domains is None: + f._analyze_domains() + return f._domains else: - return ufl.domain.extract_domains(func) + return ufl.domain.extract_domains(f) def extract_unique_domain(func): @@ -390,7 +404,7 @@ def extract_unique_domain(func): Parameters ---------- - x : firedrake.function.Function, firedrake.cofunction.Cofunction, or firedrake.constant.Constant + func : firedrake.function.Function, firedrake.cofunction.Cofunction, or firedrake.constant.Constant The function to extract the domain from. Returns @@ -399,6 +413,6 @@ def extract_unique_domain(func): Extracted domains. """ if isinstance(func, (function.Function, cofunction.Cofunction, Argument, Coargument)): - return func.function_space().mesh() + return func.function_space().mesh().unique() else: return ufl.domain.extract_unique_domain(func) diff --git a/tests/firedrake/regression/test_assemble_baseform.py b/tests/firedrake/regression/test_assemble_baseform.py index 64093ce525..ceb1abc47e 100644 --- a/tests/firedrake/regression/test_assemble_baseform.py +++ b/tests/firedrake/regression/test_assemble_baseform.py @@ -133,7 +133,7 @@ def test_scalar_formsum(f, scale): s2 = Constant(s2) elif scale == "Real": mesh = f.function_space().mesh() - R = FunctionSpace(mesh, "R", 0) + R = FunctionSpace(mesh.unique(), "R", 0) s1 = Function(R, val=s1) s2 = Function(R, val=s2) @@ -142,7 +142,7 @@ def test_scalar_formsum(f, scale): res2 = assemble(formsum) assert res2 == expected - mesh = f.function_space().mesh() + mesh = f.function_space().mesh().unique() R = FunctionSpace(mesh, "R", 0) tensor = Cofunction(R.dual()) diff --git a/tests/firedrake/regression/test_function_spaces.py b/tests/firedrake/regression/test_function_spaces.py index 3445f8773a..ab22b66ddb 100644 --- a/tests/firedrake/regression/test_function_spaces.py +++ b/tests/firedrake/regression/test_function_spaces.py @@ -249,8 +249,8 @@ def test_reconstruct_variant(family, dual): def test_reconstruct_mixed(fs, mesh, mesh2, dual): W1 = fs.dual() if dual else fs W2 = W1.reconstruct(mesh=mesh2) - assert W1.mesh() == mesh - assert W2.mesh() == mesh2 + assert W1.mesh().unique() == mesh + assert W2.mesh().unique() == mesh2 assert W1.ufl_element() == W2.ufl_element() for index, V in enumerate(W1): V1 = W1.sub(index) diff --git a/tests/firedrake/regression/test_multiple_domains.py b/tests/firedrake/regression/test_multiple_domains.py index 86766aba79..591dc9180a 100644 --- a/tests/firedrake/regression/test_multiple_domains.py +++ b/tests/firedrake/regression/test_multiple_domains.py @@ -80,10 +80,14 @@ def test_functional(mesh1, mesh2): assert np.allclose(val, cell_volume * (1 + 0.5**mesh1.topological_dimension())) +def cell_measure(primal, secondary): + return Measure("dx", primal, intersect_measures=(Measure("dx", secondary),)) + + @pytest.mark.parametrize("form,expect", [ (lambda v, mesh1, mesh2: conj(v)*dx(domain=mesh1), lambda vol, dim: vol), - (lambda v, mesh1, mesh2: conj(v)*dx(domain=mesh2), lambda vol, dim: vol*(0.5**dim)), - (lambda v, mesh1, mesh2: conj(v)*dx(domain=mesh1) + conj(v)*dx(domain=mesh2), lambda vol, dim: vol*(1 + 0.5**dim)) + (lambda v, mesh1, mesh2: conj(v)*cell_measure(mesh2, mesh1), lambda vol, dim: vol*(0.5**dim)), + (lambda v, mesh1, mesh2: conj(v)*dx(domain=mesh1) + conj(v)*cell_measure(mesh2, mesh1), lambda vol, dim: vol*(1 + 0.5**dim)) ], ids=["conj(v)*dx(mesh1)", "conj(v)*dx(mesh2)", "conj(v)*(dx(mesh1) + dx(mesh2)"]) def test_one_form(mesh1, mesh2, form, expect): V = FunctionSpace(mesh1, "DG", 0) @@ -102,8 +106,8 @@ def test_one_form(mesh1, mesh2, form, expect): @pytest.mark.parametrize("form,expect", [ (lambda u, v, mesh1, mesh2: inner(u, v)*dx(domain=mesh1), lambda vol, dim: vol), - (lambda u, v, mesh1, mesh2: inner(u, v)*dx(domain=mesh2), lambda vol, dim: vol*(0.5**dim)), - (lambda u, v, mesh1, mesh2: inner(u, v)*dx(domain=mesh1) + inner(u, v)*dx(domain=mesh2), lambda vol, dim: vol*(1 + 0.5**dim)) + (lambda u, v, mesh1, mesh2: inner(u, v)*cell_measure(mesh2, mesh1), lambda vol, dim: vol*(0.5**dim)), + (lambda u, v, mesh1, mesh2: inner(u, v)*dx(domain=mesh1) + inner(u, v)*cell_measure(mesh2, mesh1), lambda vol, dim: vol*(1 + 0.5**dim)) ], ids=["inner(u, v)*dx(mesh1)", "inner(u, v)*dx(mesh2)", "inner(u, v)*(dx(mesh1) + dx(mesh2)"]) def test_two_form(mesh1, mesh2, form, expect): V = FunctionSpace(mesh1, "DG", 0) diff --git a/tests/firedrake/submesh/test_submesh_assemble.py b/tests/firedrake/submesh/test_submesh_assemble.py new file mode 100644 index 0000000000..7a3e0136e3 --- /dev/null +++ b/tests/firedrake/submesh/test_submesh_assemble.py @@ -0,0 +1,328 @@ +import numpy as np +from firedrake import * + + +def test_submesh_assemble_cell_cell_integral_cell(): + dim = 2 + mesh = RectangleMesh(2, 1, 2., 1., quadrilateral=True) + x, y = SpatialCoordinate(mesh) + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(conditional(x > 1., 1, 0)) + mesh.mark_entities(indicator_function, 999) + subm = Submesh(mesh, dim, 999) + V0 = FunctionSpace(mesh, "CG", 1) + V1 = FunctionSpace(subm, "CG", 1) + V = V0 * V1 + u = TrialFunction(V) + v = TestFunction(V) + u0, u1 = split(u) + v0, v1 = split(v) + dx0 = Measure("dx", domain=mesh, intersect_measures=(Measure("dx", subm),)) + dx1 = Measure("dx", domain=subm, intersect_measures=(Measure("dx", mesh),)) + a = inner(u1, v0) * dx0(999) + inner(u0, v1) * dx1 + A = assemble(a, mat_type="nest") + assert np.allclose(A.M.sparsity[0][0].nnz, [1, 1, 1, 1, 1, 1]) # bc nodes + assert np.allclose(A.M.sparsity[0][1].nnz, [4, 4, 4, 4, 0, 0]) + assert np.allclose(A.M.sparsity[1][0].nnz, [4, 4, 4, 4]) + assert np.allclose(A.M.sparsity[1][1].nnz, [1, 1, 1, 1]) # bc nodes + M10 = np.array([[1./9. , 1./18., 1./36., 1./18., 0., 0.], # noqa: E203 + [1./18., 1./9. , 1./18., 1./36., 0., 0.], # noqa: E203 + [1./36., 1./18., 1./9. , 1./18., 0., 0.], # noqa: E203 + [1./18., 1./36., 1./18., 1./9. , 0., 0.]]) # noqa: E203 + assert np.allclose(A.M[0][1].values, np.transpose(M10)) + assert np.allclose(A.M[1][0].values, M10) + + +def test_submesh_assemble_cell_cell_integral_facet(): + dim = 2 + mesh = RectangleMesh(2, 1, 2., 1., quadrilateral=True) + x, y = SpatialCoordinate(mesh) + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(conditional(x > 1., 1, 0)) + mesh.mark_entities(indicator_function, 999) + subm = Submesh(mesh, dim, 999) + V0 = FunctionSpace(mesh, "DQ", 1, variant="equispaced") + V1 = FunctionSpace(subm, "DQ", 1, variant="equispaced") + V = V0 * V1 + u = TrialFunction(V) + v = TestFunction(V) + u0, u1 = split(u) + v0, v1 = split(v) + dS0 = Measure("dS", domain=mesh, intersect_measures=(Measure("ds", subm),)) + ds1 = Measure("ds", domain=subm, intersect_measures=(Measure("dS", mesh),)) + a = inner(u1, v0('+')) * dS0 + inner(u0('+'), v1) * ds1(5) + A = assemble(a, mat_type="nest") + assert np.allclose(A.M.sparsity[0][0].nnz, [1, 1, 1, 1, 1, 1, 1, 1]) # bc nodes + assert np.allclose(A.M.sparsity[0][1].nnz, [4, 4, 4, 4, 4, 4, 4, 4]) + assert np.allclose(A.M.sparsity[1][0].nnz, [8, 8, 8, 8]) + assert np.allclose(A.M.sparsity[1][1].nnz, [1, 1, 1, 1]) # bc nodes + M10 = np.array([[0., 0., 0., 0., 0., 0., 1. / 3., 1. / 6.], + [0., 0., 0., 0., 0., 0., 1. / 6., 1. / 3.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.]]) + assert np.allclose(A.M[0][1].values, np.transpose(M10)) + assert np.allclose(A.M[1][0].values, M10) + b = inner(u1, v0('+')) * ds1(5) + inner(u0('+'), v1) * dS0 + B = assemble(b, mat_type="nest") + assert np.allclose(B.M.sparsity[0][0].nnz, [1, 1, 1, 1, 1, 1, 1, 1]) # bc nodes + assert np.allclose(B.M.sparsity[0][1].nnz, [4, 4, 4, 4, 4, 4, 4, 4]) + assert np.allclose(B.M.sparsity[1][0].nnz, [8, 8, 8, 8]) + assert np.allclose(B.M.sparsity[1][1].nnz, [1, 1, 1, 1]) # bc nodes + assert np.allclose(B.M[0][1].values, A.M[0][1].values) + assert np.allclose(B.M[1][0].values, A.M[1][0].values) + + +def test_submesh_assemble_cell_cell_cell_cell_integral_various(): + # +-------+-------+-------+-------+ + # | | | | | + # | | 555 | | mesh + # | | | | | + # +-------+-------+-------+-------+ + # +-------+-------+ + # | | | + # | | 555 mesh_l + # | | | + # +-------+-------+ + # +-------+-------+ + # | | | + # 555 | | mesh_r + # | | | + # +-------+-------+ + # +-------+ + # | | + # 555 | mesh_rl + # | | + # +-------+ + dim = 2 + mesh = RectangleMesh(4, 1, 4., 1., quadrilateral=True) + x, y = SpatialCoordinate(mesh) + label_int = 555 + label_l = 81100 + label_r = 80011 + label_rl = 80010 + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + DG0 = FunctionSpace(mesh, "DG", 0) + f_int = Function(HDivTrace0).interpolate(conditional(And(x > 1.9, x < 2.1), 1, 0)) + f_l = Function(DG0).interpolate(conditional(x < 2., 1, 0)) + f_r = Function(DG0).interpolate(conditional(x > 2., 1, 0)) + f_rl = Function(DG0).interpolate(conditional(And(x > 2., x < 3.), 1, 0)) + mesh = RelabeledMesh(mesh, [f_int, f_l, f_r, f_rl], [label_int, label_l, label_r, label_rl]) + x, y = SpatialCoordinate(mesh) + mesh_l = Submesh(mesh, dim, label_l) + mesh_r = Submesh(mesh, dim, label_r) + mesh_rl = Submesh(mesh_r, dim, label_rl) + dS = Measure( + "dS", domain=mesh, + intersect_measures=( + Measure("ds", mesh_l), + Measure("ds", mesh_r), + Measure("ds", mesh_rl), + ) + ) + ds_l = Measure( + "ds", domain=mesh_l, + intersect_measures=( + Measure("dS", mesh), + Measure("ds", mesh_r), + Measure("ds", mesh_rl), + ) + ) + ds_r = Measure( + "ds", domain=mesh_r, + intersect_measures=( + Measure("dS", mesh), + Measure("ds", mesh_l), + Measure("ds", mesh_rl), + ) + ) + ds_rl = Measure( + "ds", domain=mesh_rl, + intersect_measures=( + Measure("dS", mesh), + Measure("ds", mesh_l), + Measure("ds", mesh_r), + ) + ) + n_l = FacetNormal(mesh_l) + n_rl = FacetNormal(mesh_rl) + assert assemble(dot(n_rl + n_l, n_rl + n_l) * ds_rl(label_int)) < 1.e-32 + assert assemble(dot(n_rl + n_l, n_rl + n_l) * ds_r(label_int)) < 1.e-32 + assert assemble(dot(n_rl + n_l, n_rl + n_l) * ds_l(label_int)) < 1.e-32 + assert assemble(dot(n_rl + n_l, n_rl + n_l) * dS(label_int)) < 1.e-32 + V_l = FunctionSpace(mesh_l, "DQ", 1, variant='equispaced') + V_rl = FunctionSpace(mesh_rl, "DQ", 1, variant='equispaced') + V = V_l * V_rl + u_l, u_rl = TrialFunctions(V) + v_l, v_rl = TestFunctions(V) + a = inner(u_rl, v_l) * ds_l(label_int) + inner(u_l, v_rl) * ds_rl(label_int) + A = assemble(a, mat_type="nest") + assert np.allclose(A.M.sparsity[0][0].nnz, [1, 1, 1, 1, 1, 1, 1, 1]) # bc nodes + assert np.allclose(A.M.sparsity[0][1].nnz, [4, 4, 4, 4, 0, 0, 0, 0]) + assert np.allclose(A.M.sparsity[1][0].nnz, [4, 4, 4, 4]) + assert np.allclose(A.M.sparsity[1][1].nnz, [1, 1, 1, 1]) # bc nodes + M10 = np.array([[0., 0., 1. / 3., 1. / 6., 0., 0., 0., 0.], + [0., 0., 1. / 6., 1. / 3., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.]]) + assert np.allclose(A.M[0][1].values, np.transpose(M10)) + assert np.allclose(A.M[1][0].values, M10) + b = inner(u_rl, v_l) * dS(label_int) + inner(u_l, v_rl) * dS(label_int) + B = assemble(b, mat_type="nest") + assert np.allclose(B.M.sparsity[0][0].nnz, [1, 1, 1, 1, 1, 1, 1, 1]) # bc nodes + assert np.allclose(B.M.sparsity[0][1].nnz, [4, 4, 4, 4, 0, 0, 0, 0]) + assert np.allclose(B.M.sparsity[1][0].nnz, [4, 4, 4, 4]) + assert np.allclose(B.M.sparsity[1][1].nnz, [1, 1, 1, 1]) # bc nodes + assert np.allclose(B.M[0][1].values, A.M[0][1].values) + assert np.allclose(B.M[1][0].values, A.M[1][0].values) + + +def test_submesh_assemble_cell_cell_cell_cell_integral_avg(): + # +-------+-------+-------+-------+ + # | | | | | + # | | 555 | | mesh + # | | | | | + # +-------+-------+-------+-------+ + # +-------+-------+-------+ + # | | | | + # | | 555 | mesh_l + # | | | | + # +-------+-------+-------+ + # +-------+-------+ + # | | | + # 555 | | mesh_r + # | | | + # +-------+-------+ + # +-------+ + # | | + # 555 | mesh_rl + # | | + # +-------+ + dim = 2 + mesh = RectangleMesh(4, 1, 4., 1., quadrilateral=True) + x, y = SpatialCoordinate(mesh) + label_int = 555 + label_l = 81110 + label_r = 80011 + label_rl = 80010 + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + DG0 = FunctionSpace(mesh, "DG", 0) + f_int = Function(HDivTrace0).interpolate(conditional(And(x > 1.9, x < 2.1), 1, 0)) + f_l = Function(DG0).interpolate(conditional(x < 3., 1, 0)) + f_r = Function(DG0).interpolate(conditional(x > 2., 1, 0)) + f_rl = Function(DG0).interpolate(conditional(And(x > 2., x < 3.), 1, 0)) + mesh = RelabeledMesh(mesh, [f_int, f_l, f_r, f_rl], [label_int, label_l, label_r, label_rl]) + x, y = SpatialCoordinate(mesh) + mesh_l = Submesh(mesh, dim, label_l) + x_l, y_l = SpatialCoordinate(mesh_l) + mesh_r = Submesh(mesh, dim, label_r) + x_r, y_r = SpatialCoordinate(mesh_r) + mesh_rl = Submesh(mesh_r, dim, label_rl) + x_rl, y_rl = SpatialCoordinate(mesh_rl) + dx = Measure( + "dx", domain=mesh, + intersect_measures=( + Measure("dx", mesh_l), + Measure("dx", mesh_r), + Measure("dx", mesh_rl), + ) + ) + dx_l = Measure( + "dx", domain=mesh_l, + intersect_measures=( + Measure("dx", mesh), + Measure("dx", mesh_r), + Measure("dx", mesh_rl), + ) + ) + dx_rl = Measure( + "dx", domain=mesh_rl, + intersect_measures=( + Measure("dx", mesh), + Measure("dx", mesh_l), + Measure("dx", mesh_r), + ) + ) + dS = Measure( + "dS", domain=mesh, + intersect_measures=( + Measure("dS", mesh_l), + Measure("ds", mesh_r), + Measure("ds", mesh_rl), + ) + ) + dS_l = Measure( + "dS", domain=mesh_l, + intersect_measures=( + Measure("dS", mesh), + Measure("ds", mesh_r), + Measure("ds", mesh_rl), + ) + ) + ds_rl = Measure( + "ds", domain=mesh_rl, + intersect_measures=( + Measure("dS", mesh), + Measure("dS", mesh_l), + Measure("ds", mesh_r), + ) + ) + assert abs(assemble(cell_avg(x) * dx(label_rl)) - 2.5) < 5.e-16 + assert abs(assemble(cell_avg(x) * dx_rl) - 2.5) < 5.e-16 + assert abs(assemble(cell_avg(x_rl) * dx(label_rl)) - 2.5) < 5.e-16 + assert abs(assemble(cell_avg(x_rl) * dx_l(label_rl)) - 2.5) < 5.e-16 + assert abs(assemble(cell_avg(x_l) * dx_rl) - 2.5) < 5.e-16 + assert abs(assemble(facet_avg(y * y) * dS(label_int)) - 1. / 3.) < 5.e-16 + assert abs(assemble(facet_avg(y('+') * y('-')) * ds_rl(label_int)) - 1. / 3.) < 5.e-16 + assert abs(assemble(facet_avg(y_rl * y_rl) * dS(label_int)) - 1. / 3.) < 5.e-16 + assert abs(assemble(facet_avg(y_rl * y_rl) * dS_l(label_int)) - 1. / 3.) < 5.e-16 + assert abs(assemble(facet_avg(y_l('+') * y_l('-')) * ds_rl(label_int)) - 1. / 3.) < 5.e-16 + + +def test_submesh_assemble_cell_cell_equation_bc(): + dim = 2 + mesh = RectangleMesh(2, 1, 2., 1., quadrilateral=True) + x, y = SpatialCoordinate(mesh) + label_int = 555 + label_l = 810 + label_r = 801 + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + DQ0 = FunctionSpace(mesh, "DQ", 0) + f_int = Function(HDivTrace0).interpolate(conditional(And(x > 0.9, x < 1.1), 1, 0)) + f_l = Function(DQ0).interpolate(conditional(x < 1., 1, 0)) + f_r = Function(DQ0).interpolate(conditional(x > 1., 1, 0)) + mesh = RelabeledMesh(mesh, [f_int, f_l, f_r], [label_int, label_l, label_r]) + mesh_l = Submesh(mesh, dim, label_l) + mesh_r = Submesh(mesh, dim, label_r) + V_l = FunctionSpace(mesh_l, "CG", 1) + V_r = FunctionSpace(mesh_r, "CG", 1) + V = V_l * V_r + u = TrialFunction(V) + v = TestFunction(V) + u_l, u_r = split(u) + v_l, v_r = split(v) + dx_l = Measure("dx", domain=mesh_l) + ds_l = Measure("ds", domain=mesh_l, intersect_measures=(Measure("ds", mesh_r),)) + a = inner(u_l, v_l) * dx_l + a_int = inner(u_l - u_r, v_l) * ds_l(label_int) + L_int = inner(Constant(0), v_l) * ds_l(label_int) + sol = Function(V) + bc = EquationBC(a_int == L_int, sol, label_int, V=V.sub(0)) + A = assemble(a, bcs=bc.extract_form('J'), mat_type="nest") + assert np.allclose(Function(V_l).interpolate(SpatialCoordinate(mesh_l)[0]).dat.data, [0., 0., 1., 1.]) + assert np.allclose(Function(V_l).interpolate(SpatialCoordinate(mesh_l)[1]).dat.data, [0., 1., 1., 0.]) + assert np.allclose(Function(V_r).interpolate(SpatialCoordinate(mesh_r)[0]).dat.data, [1., 1., 2., 2.]) + assert np.allclose(Function(V_r).interpolate(SpatialCoordinate(mesh_r)[1]).dat.data, [0., 1., 1., 0.]) + assert np.allclose(A.M.sparsity[0][0].nnz, [4, 4, 4, 4]) + assert np.allclose(A.M.sparsity[0][1].nnz, [4, 4, 4, 4]) + assert np.allclose(A.M.sparsity[1][0].nnz, [0, 0, 0, 0]) + assert np.allclose(A.M.sparsity[1][1].nnz, [1, 1, 1, 1]) # bc nodes + M00 = np.array([[1. / 9. , 1. / 18., 1. / 36., 1. / 18.], # noqa: E203 + [1. / 18., 1. / 9. , 1. / 18., 1. / 36.], # noqa: E203 + [0., 0., 1. / 3., 1. / 6.], + [0., 0., 1. / 6., 1. / 3.]]) + M01 = np.array([[0., 0., 0., 0.], + [0., 0., 0., 0.], + [- 1. / 6., - 1. / 3., 0., 0.], + [- 1. / 3., - 1. / 6., 0., 0.]]) + assert np.allclose(A.M[0][0].values, M00) + assert np.allclose(A.M[0][1].values, M01) diff --git a/tests/firedrake/submesh/test_submesh_base.py b/tests/firedrake/submesh/test_submesh_base.py new file mode 100644 index 0000000000..910eebf934 --- /dev/null +++ b/tests/firedrake/submesh/test_submesh_base.py @@ -0,0 +1,275 @@ +import pytest +from firedrake import * + + +def _get_expr(m): + if m.geometric_dimension() == 1: + x, = SpatialCoordinate(m) + y = x * x + z = x + y + elif m.geometric_dimension() == 2: + x, y = SpatialCoordinate(m) + z = x + y + elif m.geometric_dimension() == 3: + x, y, z = SpatialCoordinate(m) + else: + raise NotImplementedError("Not implemented") + return exp(x + y * y + z * z * z) + + +def _test_submesh_base_cell_integral_quad(family_degree, nelem): + dim = 2 + family, degree = family_degree + mesh = UnitSquareMesh(nelem, nelem, quadrilateral=True) + V = FunctionSpace(mesh, family, degree) + f = Function(V).interpolate(_get_expr(mesh)) + x, y = SpatialCoordinate(mesh) + cond = conditional(x > .5, 1, + conditional(y > .5, 1, 0)) # noqa: E128 + target = assemble(f * cond * dx) + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(cond) + label_value = 999 + mesh.mark_entities(indicator_function, label_value) + msub = Submesh(mesh, dim, label_value) + Vsub = FunctionSpace(msub, family, degree) + fsub = Function(Vsub).interpolate(_get_expr(msub)) + result = assemble(fsub * dx) + assert abs(result - target) < 1e-12 + + +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_cell_integral_quad_1_process(family_degree, nelem): + _test_submesh_base_cell_integral_quad(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_cell_integral_quad_2_processes(family_degree, nelem): + _test_submesh_base_cell_integral_quad(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_cell_integral_quad_3_processes(family_degree, nelem): + _test_submesh_base_cell_integral_quad(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_cell_integral_quad_4_processes(family_degree, nelem): + _test_submesh_base_cell_integral_quad(family_degree, nelem) + + +def _test_submesh_base_facet_integral_quad(family_degree, nelem): + dim = 2 + family, degree = family_degree + mesh = UnitSquareMesh(nelem, nelem, quadrilateral=True) + x, y = SpatialCoordinate(mesh) + cond = conditional(x > .5, 1, + conditional(y > .5, 1, 0)) # noqa: E128 + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(cond) + label_value = 999 + mesh.mark_entities(indicator_function, label_value) + subm = Submesh(mesh, dim, label_value) + for i in [1, 2, 3, 4]: + target = assemble(cond * _get_expr(mesh) * ds(i)) + result = assemble(_get_expr(subm) * ds(i)) + assert abs(result - target) < 2e-12 + # Check new boundary. + assert abs(assemble(Constant(1.) * ds(subdomain_id=5, domain=subm)) - 1.0) < 1e-12 + x, y = SpatialCoordinate(subm) + assert abs(assemble(x**4 * ds(5)) - (.5**5 / 5 + .5**4 * .5)) < 1e-12 + assert abs(assemble(y**4 * ds(5)) - (.5**5 / 5 + .5**4 * .5)) < 1e-12 + + +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_facet_integral_quad_1_process(family_degree, nelem): + _test_submesh_base_facet_integral_quad(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_facet_integral_quad_2_processes(family_degree, nelem): + _test_submesh_base_facet_integral_quad(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_facet_integral_quad_3_processes(family_degree, nelem): + _test_submesh_base_facet_integral_quad(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8, 16]) +def test_submesh_base_facet_integral_quad_4_processes(family_degree, nelem): + _test_submesh_base_facet_integral_quad(family_degree, nelem) + + +def _test_submesh_base_cell_integral_hex(family_degree, nelem): + dim = 3 + family, degree = family_degree + mesh = UnitCubeMesh(nelem, nelem, nelem, hexahedral=True) + V = FunctionSpace(mesh, family, degree) + f = Function(V).interpolate(_get_expr(mesh)) + x, y, z = SpatialCoordinate(mesh) + cond = conditional(x > .5, 1, + conditional(y > .5, 1, # noqa: E128 + conditional(z > .5, 1, 0))) # noqa: E128 + target = assemble(f * cond * dx) + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(cond) + label_value = 999 + mesh.mark_entities(indicator_function, label_value) + msub = Submesh(mesh, dim, label_value) + Vsub = FunctionSpace(msub, family, degree) + fsub = Function(Vsub).interpolate(_get_expr(msub)) + result = assemble(fsub * dx) + assert abs(result - target) < 1e-12 + + +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8]) +def test_submesh_base_cell_integral_hex_1_process(family_degree, nelem): + _test_submesh_base_cell_integral_hex(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8]) +def test_submesh_base_cell_integral_hex_2_processes(family_degree, nelem): + _test_submesh_base_cell_integral_hex(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('family_degree', [("Q", 4), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8]) +def test_submesh_base_cell_integral_hex_4_processes(family_degree, nelem): + _test_submesh_base_cell_integral_hex(family_degree, nelem) + + +def _test_submesh_base_facet_integral_hex(family_degree, nelem): + dim = 3 + family, degree = family_degree + mesh = UnitCubeMesh(nelem, nelem, nelem, hexahedral=True) + x, y, z = SpatialCoordinate(mesh) + cond = conditional(x > .5, 1, + conditional(y > .5, 1, # noqa: E128 + conditional(z > .5, 1, 0))) # noqa: E128 + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(cond) + label_value = 999 + mesh.mark_entities(indicator_function, label_value) + subm = Submesh(mesh, dim, label_value) + for i in [1, 2, 3, 4, 5, 6]: + target = assemble(cond * _get_expr(mesh) * ds(i)) + result = assemble(_get_expr(subm) * ds(i)) + assert abs(result - target) < 2e-12 + # Check new boundary. + assert abs(assemble(Constant(1) * ds(subdomain_id=7, domain=subm)) - .75) < 1e-12 + x, y, z = SpatialCoordinate(subm) + assert abs(assemble(x**4 * ds(7)) - (.5**5 / 5 * .5 * 2 + .5**4 * .5**2)) < 1e-12 + assert abs(assemble(y**4 * ds(7)) - (.5**5 / 5 * .5 * 2 + .5**4 * .5**2)) < 1e-12 + assert abs(assemble(z**4 * ds(7)) - (.5**5 / 5 * .5 * 2 + .5**4 * .5**2)) < 1e-12 + + +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8]) +def test_submesh_base_facet_integral_hex_1_process(family_degree, nelem): + _test_submesh_base_facet_integral_hex(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8]) +def test_submesh_base_facet_integral_hex_2_processes(family_degree, nelem): + _test_submesh_base_facet_integral_hex(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('family_degree', [("Q", 3), ]) +@pytest.mark.parametrize('nelem', [2, 4, 8]) +def test_submesh_base_facet_integral_hex_4_processes(family_degree, nelem): + _test_submesh_base_facet_integral_hex(family_degree, nelem) + + +@pytest.mark.parallel(nprocs=2) +def test_submesh_base_entity_maps(): + + # 3---9--(5)-(12)(7) (7)-(13)-3---9---5 + # | | | | | | + # 8 0 (11) (1) (13) (12) (1) 8 0 10 mesh + # | | | | | | + # 2--10--(4)(14)-(6) (6)-(14)-2--11---4 + # + # 2---6---4 (4)-(7)-(2) + # | | | | + # 5 0 8 (6) (0) (5) submesh + # | | | | + # 1---7---3 (3)-(8)-(1) + # + # rank 0 rank 1 + + dim = 2 + mesh = RectangleMesh(2, 1, 2., 1., quadrilateral=True, distribution_parameters={"partitioner_type": "simple"}) + assert mesh.comm.size == 2 + rank = mesh.comm.rank + x, y = SpatialCoordinate(mesh) + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(conditional(x < 1., 1, 0)) + label_value = 999 + mesh.mark_entities(indicator_function, label_value) + submesh = Submesh(mesh, dim, label_value) + submesh.topology_dm.viewFromOptions("-dm_view") + subdm = submesh.topology.topology_dm + if rank == 0: + assert subdm.getLabel("pyop2_core").getStratumSize(1) == 0 + assert subdm.getLabel("pyop2_owned").getStratumSize(1) == 9 + assert subdm.getLabel("pyop2_ghost").getStratumSize(1) == 0 + assert (subdm.getLabel("pyop2_owned").getStratumIS(1).getIndices() == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])).all() + assert (mesh.interior_facets.facets == np.array([11])).all + assert (mesh.exterior_facets.facets == np.array([8, 9, 10, 12, 13, 14])).all + assert (submesh.interior_facets.facets == np.array([])).all + assert (submesh.exterior_facets.facets == np.array([5, 6, 8, 7])).all() + else: + assert subdm.getLabel("pyop2_core").getStratumSize(1) == 0 + assert subdm.getLabel("pyop2_owned").getStratumSize(1) == 0 + assert subdm.getLabel("pyop2_ghost").getStratumSize(1) == 9 + assert (subdm.getLabel("pyop2_ghost").getStratumIS(1).getIndices() == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])).all() + assert (mesh.interior_facets.facets == np.array([8])).all + assert (mesh.exterior_facets.facets == np.array([9, 10, 11, 12, 13, 14])).all + assert (submesh.interior_facets.facets == np.array([])).all + assert (submesh.exterior_facets.facets == np.array([6, 7, 5, 8])).all() + composed_map, integral_type = mesh.topology.trans_mesh_entity_map(submesh.topology, "cell", None, None) + assert integral_type == "cell" + if rank == 0: + assert (composed_map.maps_[0].values_with_halo == np.array([0])).all() + else: + assert (composed_map.maps_[0].values_with_halo == np.array([1])).all() + composed_map, integral_type = mesh.topology.trans_mesh_entity_map(submesh.topology, "exterior_facet", 5, None) + assert integral_type == "interior_facet" + if rank == 0: + assert (composed_map.maps_[0].values_with_halo == np.array([-1, -1, 0, -1]).reshape((-1, 1))).all() # entire exterior-interior map + else: + assert (composed_map.maps_[0].values_with_halo == np.array([-1, -1, 0, -1]).reshape((-1, 1))).all() # entire exterior-interior map + composed_map, integral_type = mesh.topology.trans_mesh_entity_map(submesh.topology, "exterior_facet", 4, None) + assert integral_type == "exterior_facet" + if rank == 0: + assert (composed_map.maps_[0].values_with_halo == np.array([0, 1, -1, 2]).reshape((-1, 1))).all() # entire exterior-exterior map + else: + assert (composed_map.maps_[0].values_with_halo == np.array([3, 4, -1, 5]).reshape((-1, 1))).all() # entire exterior-exterior map + composed_map, integral_type = submesh.topology.trans_mesh_entity_map(mesh.topology, "exterior_facet", 1, None) + assert integral_type == "exterior_facet" + if rank == 0: + assert (composed_map.maps_[0].values_with_halo == np.array([0, 1, 3, -1, -1, -1]).reshape((-1, 1))).all() + else: + assert (composed_map.maps_[0].values_with_halo == np.array([-1, -1, -1, 0, 1, 3]).reshape((-1, 1))).all() diff --git a/tests/firedrake/submesh/test_submesh_solve.py b/tests/firedrake/submesh/test_submesh_solve.py new file mode 100644 index 0000000000..0691646748 --- /dev/null +++ b/tests/firedrake/submesh/test_submesh_solve.py @@ -0,0 +1,460 @@ +import pytest +from os.path import abspath, dirname, join +import numpy as np +from firedrake import * + + +cwd = abspath(dirname(__file__)) + + +def _solve_helmholtz(mesh): + V = FunctionSpace(mesh, "CG", 1) + u = TrialFunction(V) + v = TestFunction(V) + x = SpatialCoordinate(mesh) + u_exact = sin(x[0]) * sin(x[1]) + f = Function(V).interpolate(2 * u_exact) + a = (inner(grad(u), grad(v)) + inner(u, v)) * dx + L = inner(f, v) * dx + bc = DirichletBC(V, u_exact, "on_boundary") + sol = Function(V) + solve(a == L, sol, bcs=[bc], solver_parameters={'ksp_type': 'preonly', + 'pc_type': 'lu'}) + return sqrt(assemble((sol - u_exact)**2 * dx)) + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('nelem', [2, 4]) +@pytest.mark.parametrize('distribution_parameters', [None, {"overlap_type": (DistributedMeshOverlapType.NONE, 0)}]) +def test_submesh_solve_simple(nelem, distribution_parameters): + dim = 2 + # Compute reference error. + mesh = RectangleMesh(nelem, nelem * 2, 1., 1., quadrilateral=True, distribution_parameters=distribution_parameters) + error = _solve_helmholtz(mesh) + # Compute submesh error. + mesh = RectangleMesh(nelem * 2, nelem * 2, 2., 1., quadrilateral=True, distribution_parameters=distribution_parameters) + x, y = SpatialCoordinate(mesh) + DQ0 = FunctionSpace(mesh, "DQ", 0) + indicator_function = Function(DQ0).interpolate(conditional(x < 1., 1, 0)) + mesh.mark_entities(indicator_function, 999) + mesh = Submesh(mesh, dim, 999) + suberror = _solve_helmholtz(mesh) + assert abs(error - suberror) < 1e-15 + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize('dim', [2, 3]) +@pytest.mark.parametrize('simplex', [True, False]) +def test_submesh_solve_cell_cell_mixed_scalar(dim, simplex): + if dim == 2: + if simplex: + mesh = Mesh(join(cwd, "..", "..", "..", "docs", "notebooks/stokes-control.msh")) + bid = (1, 2, 3, 4, 5) + submesh_expr = lambda x: conditional(x[0] < 10., 1, 0) + solution_expr = lambda x: x[0] + x[1] + else: + mesh = Mesh(join(cwd, "..", "meshes", "unitsquare_unstructured_quadrilaterals.msh")) + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + x, y = SpatialCoordinate(mesh) + hdivtrace0x = Function(HDivTrace0).interpolate(conditional(And(x > .001, x < .999), 0, 1)) + hdivtrace0y = Function(HDivTrace0).interpolate(conditional(And(y > .001, y < .999), 0, 1)) + mesh = RelabeledMesh(mesh, [hdivtrace0x, hdivtrace0y], [111, 222]) + bid = (111, 222) + submesh_expr = lambda x: conditional(x[0] < .5, 1, 0) + solution_expr = lambda x: x[0] + x[1] + elif dim == 3: + if simplex: + nref = 3 + mesh = BoxMesh(2 ** nref, 2 ** nref, 2 ** nref, 1., 1., 1., hexahedral=False) + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + else: + mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh")) + HDivTrace0 = FunctionSpace(mesh, "Q", 2) + x, y, z = SpatialCoordinate(mesh) + hdivtrace0x = Function(HDivTrace0).interpolate(conditional(And(x > .001, x < .999), 0, 1)) + hdivtrace0y = Function(HDivTrace0).interpolate(conditional(And(y > .001, y < .999), 0, 1)) + hdivtrace0z = Function(HDivTrace0).interpolate(conditional(And(z > .001, z < .999), 0, 1)) + mesh = RelabeledMesh(mesh, [hdivtrace0x, hdivtrace0y, hdivtrace0z], [111, 222, 333]) + bid = (111, 222, 333) + submesh_expr = lambda x: conditional(x[0] > .5, 1, 0) + solution_expr = lambda x: x[0] + x[1] + x[2] + else: + raise NotImplementedError + DG0 = FunctionSpace(mesh, "DG", 0) + submesh_function = Function(DG0).interpolate(submesh_expr(SpatialCoordinate(mesh))) + submesh_label = 999 + mesh.mark_entities(submesh_function, submesh_label) + subm = Submesh(mesh, dim, submesh_label) + V0 = FunctionSpace(mesh, "CG", 2) + V1 = FunctionSpace(subm, "CG", 3) + V = V0 * V1 + u = TrialFunction(V) + v = TestFunction(V) + u0, u1 = split(u) + v0, v1 = split(v) + dx0 = Measure("dx", domain=mesh, intersect_measures=(Measure("dx", subm),)) + dx1 = Measure("dx", domain=subm, intersect_measures=(Measure("dx", mesh),)) + a = inner(grad(u0), grad(v0)) * dx0 + inner(u0 - u1, v1) * dx1 + L = inner(Constant(0.), v1) * dx1 + g = Function(V0).interpolate(solution_expr(SpatialCoordinate(mesh))) + bc = DirichletBC(V.sub(0), g, bid) + solution = Function(V) + solve(a == L, solution, bcs=[bc]) + target = Function(V1).interpolate(solution_expr(SpatialCoordinate(subm))) + assert np.allclose(solution.subfunctions[1].dat.data_ro_with_halos, target.dat.data_ro_with_halos) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize('dim', [2, 3]) +@pytest.mark.parametrize('simplex', [True, False]) +def test_submesh_solve_cell_cell_mixed_vector(dim, simplex): + if dim == 2: + if simplex: + mesh = Mesh(join(cwd, "..", "..", "..", "docs", "notebooks/stokes-control.msh")) + submesh_expr = lambda x: conditional(x[0] < 10., 1, 0) + elem0 = FiniteElement("RT", "triangle", 3) + elem1 = VectorElement("P", "triangle", 3) + else: + mesh = Mesh(join(cwd, "..", "meshes", "unitsquare_unstructured_quadrilaterals.msh")) + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + x, y = SpatialCoordinate(mesh) + hdivtrace0x = Function(HDivTrace0).interpolate(conditional(And(x > .001, x < .999), 0, 1)) + hdivtrace0y = Function(HDivTrace0).interpolate(conditional(And(y > .001, y < .999), 0, 1)) + mesh = RelabeledMesh(mesh, [hdivtrace0x, hdivtrace0y], [111, 222]) + submesh_expr = lambda x: conditional(x[0] < .5, 1, 0) + elem0 = FiniteElement("RTCF", "quadrilateral", 2) + elem1 = VectorElement("Q", "quadrilateral", 3) + elif dim == 3: + if simplex: + nref = 3 + mesh = BoxMesh(2 ** nref, 2 ** nref, 2 ** nref, 1., 1., 1., hexahedral=False) + x, y, z = SpatialCoordinate(mesh) + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + hdivtrace0x = Function(HDivTrace0).interpolate(conditional(And(x > .001, x < .999), 0, 1)) + hdivtrace0y = Function(HDivTrace0).interpolate(conditional(And(y > .001, y < .999), 0, 1)) + hdivtrace0z = Function(HDivTrace0).interpolate(conditional(And(z > .001, z < .999), 0, 1)) + mesh = RelabeledMesh(mesh, [hdivtrace0x, hdivtrace0y, hdivtrace0z], [111, 222, 333]) + submesh_expr = lambda x: conditional(x[0] > .5, 1, 0) + elem0 = FiniteElement("N1F", "tetrahedron", 3) + elem1 = VectorElement("P", "tetrahedron", 3) + else: + mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh")) + HDivTrace0 = FunctionSpace(mesh, "Q", 2) + x, y, z = SpatialCoordinate(mesh) + hdivtrace0x = Function(HDivTrace0).interpolate(conditional(And(x > .001, x < .999), 0, 1)) + hdivtrace0y = Function(HDivTrace0).interpolate(conditional(And(y > .001, y < .999), 0, 1)) + hdivtrace0z = Function(HDivTrace0).interpolate(conditional(And(z > .001, z < .999), 0, 1)) + mesh = RelabeledMesh(mesh, [hdivtrace0x, hdivtrace0y, hdivtrace0z], [111, 222, 333]) + submesh_expr = lambda x: conditional(x[0] > .5, 1, 0) + elem0 = FiniteElement("NCF", "hexahedron", 2) + elem1 = VectorElement("Q", "hexahedron", 3) + with pytest.raises(NotImplementedError): + _ = FunctionSpace(mesh, elem0) + return + else: + raise NotImplementedError + DG0 = FunctionSpace(mesh, "DG", 0) + submesh_function = Function(DG0).interpolate(submesh_expr(SpatialCoordinate(mesh))) + submesh_label = 999 + mesh.mark_entities(submesh_function, submesh_label) + subm = Submesh(mesh, dim, submesh_label) + V0 = FunctionSpace(mesh, elem0) + V1 = FunctionSpace(subm, elem1) + V = V0 * V1 + u = TrialFunction(V) + v = TestFunction(V) + u0, u1 = split(u) + v0, v1 = split(v) + dx0 = Measure("dx", domain=mesh, intersect_measures=(Measure("dx", subm),)) + dx1 = Measure("dx", domain=subm, intersect_measures=(Measure("dx", mesh),)) + a = inner(u0, v0) * dx0 + inner(u0 - u1, v1) * dx1 + L = inner(SpatialCoordinate(mesh), v0) * dx0 + solution = Function(V) + solve(a == L, solution) + s0, s1 = split(solution) + x = SpatialCoordinate(subm) + assert assemble(inner(s1 - x, s1 - x) * dx1) < 1.e-20 + + +def _mixed_poisson_create_mesh_2d(nref, quadrilateral, submesh_region, label_submesh, label_submesh_compl): + # y + # | + # | + # 1.0 +--17---+--18---+ + # | | | + # 12 20 14 + # | | | + # 0.5 +--21---+--22---+ + # | | | + # 11 19 13 + # | | | + # 0.0 +--15---+--16---+----x + # + # 0.0 0.5 1.0 + mesh = UnitSquareMesh(2 ** nref, 2 ** nref, quadrilateral=quadrilateral) + eps = 1. / (2 ** nref) / 100. + x, y = SpatialCoordinate(mesh) + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + f11 = Function(HDivTrace0).interpolate(conditional(And(x < eps, y < .5), 1, 0)) + f12 = Function(HDivTrace0).interpolate(conditional(And(x < eps, y > .5), 1, 0)) + f13 = Function(HDivTrace0).interpolate(conditional(And(x > 1 - eps, y < .5), 1, 0)) + f14 = Function(HDivTrace0).interpolate(conditional(And(x > 1 - eps, y > .5), 1, 0)) + f15 = Function(HDivTrace0).interpolate(conditional(And(x < .5, y < eps), 1, 0)) + f16 = Function(HDivTrace0).interpolate(conditional(And(x > .5, y < eps), 1, 0)) + f17 = Function(HDivTrace0).interpolate(conditional(And(x < .5, y > 1 - eps), 1, 0)) + f18 = Function(HDivTrace0).interpolate(conditional(And(x > .5, y > 1 - eps), 1, 0)) + f19 = Function(HDivTrace0).interpolate(conditional(And(And(x > .5 - eps, x < .5 + eps), y < .5), 1, 0)) + f20 = Function(HDivTrace0).interpolate(conditional(And(And(x > .5 - eps, x < .5 + eps), y > .5), 1, 0)) + f21 = Function(HDivTrace0).interpolate(conditional(And(x < .5, And(y > .5 - eps, y < .5 + eps)), 1, 0)) + f22 = Function(HDivTrace0).interpolate(conditional(And(x > .5, And(y > .5 - eps, y < .5 + eps)), 1, 0)) + DG0 = FunctionSpace(mesh, "DG", 0) + if submesh_region == "left": + submesh_function = Function(DG0).interpolate(conditional(x < .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(x > .5, 1, 0)) + elif submesh_region == "right": + submesh_function = Function(DG0).interpolate(conditional(x > .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(x < .5, 1, 0)) + elif submesh_region == "bottom": + submesh_function = Function(DG0).interpolate(conditional(y < .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(y > .5, 1, 0)) + elif submesh_region == "top": + submesh_function = Function(DG0).interpolate(conditional(y > .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(y < .5, 1, 0)) + else: + raise NotImplementedError(f"Unknown submesh_region: {submesh_region}") + return RelabeledMesh(mesh, [f11, f12, f13, f14, f15, f16, f17, f18, f19, f20, f21, f22, submesh_function, submesh_function_compl], + [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, label_submesh, label_submesh_compl]) + + +def _mixed_poisson_solve_2d(nref, degree, quadrilateral, submesh_region): + dim = 2 + label_submesh = 999 + label_submesh_compl = 888 + mesh = _mixed_poisson_create_mesh_2d(nref, quadrilateral, submesh_region, label_submesh, label_submesh_compl) + x, y = SpatialCoordinate(mesh) + subm = Submesh(mesh, dim, label_submesh) + subx, suby = SpatialCoordinate(subm) + if submesh_region == "left": + boun_ext = (11, 12) + boun_int = (19, 20) + boun_dirichlet = (15, 17) + elif submesh_region == "right": + boun_ext = (13, 14) + boun_int = (19, 20) + boun_dirichlet = (16, 18) + elif submesh_region == "bottom": + boun_ext = (15, 16) + boun_int = (21, 22) + boun_dirichlet = (11, 13) + elif submesh_region == "top": + boun_ext = (17, 18) + boun_int = (21, 22) + boun_dirichlet = (12, 14) + else: + raise NotImplementedError(f"Unknown submesh_region: {submesh_region}") + BDM = FunctionSpace(subm, "RTCF" if quadrilateral else "BDM", degree) + DG = FunctionSpace(mesh, "DG", degree - 1) + W = BDM * DG + tau, v = TestFunctions(W) + nsub = FacetNormal(subm) + u_exact = Function(DG).interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + sigma_exact = Function(BDM).project(as_vector([- 2 * pi * sin(2 * pi * subx) * cos(2 * pi * suby), - 2 * pi * cos(2 * pi * subx) * sin(2 * pi * suby)]), + solver_parameters={"ksp_type": "cg", "ksp_rtol": 1.e-16}) + f = Function(DG).interpolate(- 8 * pi * pi * cos(2 * pi * x) * cos(2 * pi * y)) + dx0 = Measure("dx", domain=mesh, intersect_measures=(Measure("dx", subm),)) + dx1 = Measure("dx", domain=subm, intersect_measures=(Measure("dx", mesh),)) + ds0 = Measure("ds", domain=mesh, intersect_measures=(Measure("ds", subm),)) + ds1_ext = Measure("ds", domain=subm, intersect_measures=(Measure("ds", mesh),)) + ds1_int = Measure("ds", domain=subm, intersect_measures=(Measure("dS", mesh),)) + dS0 = Measure("dS", domain=mesh, intersect_measures=(Measure("ds", subm),)) + bc = DirichletBC(W.sub(0), sigma_exact, boun_dirichlet) + # Do the base case. + w = Function(W) + sigma, u = split(w) + a = (inner(sigma, tau) + inner(u, div(tau)) + inner(div(sigma), v)) * dx1 + inner(u - u_exact, v) * dx0(label_submesh_compl) + L = inner(f, v) * dx1 + inner((u('+') + u('-')) / 2., dot(tau, nsub)) * dS0(boun_int) + inner(u_exact, dot(tau, nsub)) * ds0(boun_ext) + solve(a - L == 0, w, bcs=[bc]) + # Change domains of integration. + w_ = Function(W) + sigma_, u_ = split(w_) + a_ = (inner(sigma_, tau) + inner(u_, div(tau)) + inner(div(sigma_), v)) * dx1 + inner(u_ - u_exact, v) * dx0(label_submesh_compl) + L_ = inner(f, v) * dx0(label_submesh) + inner((u_('+') + u_('-')) / 2., dot(tau, nsub)) * ds1_int(boun_int) + inner(u_exact, dot(tau, nsub)) * ds1_ext(boun_ext) + solve(a_ - L_ == 0, w_, bcs=[bc]) + assert assemble(inner(sigma_ - sigma, sigma_ - sigma) * dx1) < 1.e-20 + assert assemble(inner(u_ - u, u_ - u) * dx0(label_submesh)) < 1.e-20 + sigma_error = sqrt(assemble(inner(sigma - sigma_exact, sigma - sigma_exact) * dx1)) + u_error = sqrt(assemble(inner(u - u_exact, u - u_exact) * dx0(label_submesh))) + return sigma_error, u_error + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('nref', [1, 2, 3, 4]) +@pytest.mark.parametrize('degree', [1]) +@pytest.mark.parametrize('quadrilateral', [False, True]) +@pytest.mark.parametrize('submesh_region', ["left", "right", "bottom", "top"]) +def test_submesh_solve_mixed_poisson_check_sanity_2d(nref, degree, quadrilateral, submesh_region): + _, _ = _mixed_poisson_solve_2d(nref, degree, quadrilateral, submesh_region) + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('quadrilateral', [True]) +@pytest.mark.parametrize('degree', [3]) +@pytest.mark.parametrize('submesh_region', ["left", "right"]) +def test_submesh_solve_mixed_poisson_check_convergence_2d(quadrilateral, degree, submesh_region): + nrefs = [5, 6, 7] + start = nrefs[0] + s_error_array = np.zeros(len(nrefs)) + u_error_array = np.zeros(len(nrefs)) + for nref in nrefs: + i = nref - start + s_error_array[i], u_error_array[i] = _mixed_poisson_solve_2d(nref, degree, quadrilateral, submesh_region) + assert (np.log2(s_error_array[:-1] / s_error_array[1:]) > degree + .95).all() + assert (np.log2(u_error_array[:-1] / u_error_array[1:]) > degree + .95).all() + + +def _mixed_poisson_create_mesh_3d(hexahedral, submesh_region, label_submesh, label_submesh_compl): + if hexahedral: + mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh")) + DG0 = FunctionSpace(mesh, "DQ", 0) + HDivTrace0 = FunctionSpace(mesh, "Q", 2) + else: + mesh = BoxMesh(4, 4, 4, 1., 1., 1., hexahedral=False) + DG0 = FunctionSpace(mesh, "DP", 0) + HDivTrace0 = FunctionSpace(mesh, "HDiv Trace", 0) + x, y, z = SpatialCoordinate(mesh) + eps = 1.e-6 + f101 = Function(HDivTrace0).interpolate(conditional(x < eps, 1, 0)) + f102 = Function(HDivTrace0).interpolate(conditional(x > 1. - eps, 1, 0)) + f103 = Function(HDivTrace0).interpolate(conditional(y < eps, 1, 0)) + f104 = Function(HDivTrace0).interpolate(conditional(y > 1. - eps, 1, 0)) + f105 = Function(HDivTrace0).interpolate(conditional(z < eps, 1, 0)) + f106 = Function(HDivTrace0).interpolate(conditional(z > 1. - eps, 1, 0)) + if submesh_region == "left": + submesh_function = Function(DG0).interpolate(conditional(x < .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(x > .5, 1, 0)) + elif submesh_region == "right": + submesh_function = Function(DG0).interpolate(conditional(x > .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(x < .5, 1, 0)) + elif submesh_region == "front": + submesh_function = Function(DG0).interpolate(conditional(y < .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(y > .5, 1, 0)) + elif submesh_region == "back": + submesh_function = Function(DG0).interpolate(conditional(y > .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(y < .5, 1, 0)) + elif submesh_region == "bottom": + submesh_function = Function(DG0).interpolate(conditional(z < .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(z > .5, 1, 0)) + elif submesh_region == "top": + submesh_function = Function(DG0).interpolate(conditional(z > .5, 1, 0)) + submesh_function_compl = Function(DG0).interpolate(conditional(z < .5, 1, 0)) + else: + raise NotImplementedError(f"Unknown submesh_region: {submesh_region}") + return RelabeledMesh(mesh, [f101, f102, f103, f104, f105, f106, submesh_function, submesh_function_compl], + [101, 102, 103, 104, 105, 106, label_submesh, label_submesh_compl]) + + +def _mixed_poisson_solve_3d(hexahedral, degree, submesh_region): + dim = 3 + label_submesh = 999 + label_submesh_compl = 888 + mesh = _mixed_poisson_create_mesh_3d(hexahedral, submesh_region, label_submesh, label_submesh_compl) + x, y, z = SpatialCoordinate(mesh) + subm = Submesh(mesh, dim, label_submesh) + subx, suby, subz = SpatialCoordinate(subm) + if submesh_region == "left": + boun_ext = (101, ) + boun_dirichlet = (103, 104, 105, 106) + elif submesh_region == "right": + boun_ext = (102, ) + boun_dirichlet = (103, 104, 105, 106) + elif submesh_region == "front": + boun_ext = (103, ) + boun_dirichlet = (101, 102, 105, 106) + elif submesh_region == "back": + boun_ext = (104, ) + boun_dirichlet = (101, 102, 105, 106) + elif submesh_region == "bottom": + boun_ext = (105, ) + boun_dirichlet = (101, 102, 103, 104) + elif submesh_region == "top": + boun_ext = (106, ) + boun_dirichlet = (101, 102, 103, 104) + else: + raise NotImplementedError(f"Unknown submesh_region: {submesh_region}") + boun_int = (107, ) # labeled automatically. + NCF = FunctionSpace(subm, "NCF" if hexahedral else "N2F", degree) + DG = FunctionSpace(mesh, "DG", degree - 1) + W = NCF * DG + tau, v = TestFunctions(W) + nsub = FacetNormal(subm) + u_exact = Function(DG).interpolate(cos(2 * pi * x) * cos(2 * pi * y) * cos(2 * pi * z)) + sigma_exact = Function(NCF).project(as_vector([- 2 * pi * sin(2 * pi * subx) * cos(2 * pi * suby) * cos(2 * pi * subz), + - 2 * pi * cos(2 * pi * subx) * sin(2 * pi * suby) * cos(2 * pi * subz), + - 2 * pi * cos(2 * pi * subx) * cos(2 * pi * suby) * sin(2 * pi * subz)]), + solver_parameters={"ksp_type": "cg", "ksp_rtol": 1.e-16}) + f = Function(DG).interpolate(- 12 * pi * pi * cos(2 * pi * x) * cos(2 * pi * y) * cos(2 * pi * z)) + dx0 = Measure("dx", domain=mesh, intersect_measures=(Measure("dx", subm),)) + dx1 = Measure("dx", domain=subm, intersect_measures=(Measure("dx", mesh),)) + ds0 = Measure("ds", domain=mesh, intersect_measures=(Measure("ds", subm),)) + ds1 = Measure("ds", domain=subm, intersect_measures=(Measure("dS", mesh),)) + bc = DirichletBC(W.sub(0), sigma_exact, boun_dirichlet) + # Do the base case. + w = Function(W) + sigma, u = split(w) + a = (inner(sigma, tau) + inner(u, div(tau)) + inner(div(sigma), v)) * dx1 + inner(u - u_exact, v) * dx0(label_submesh_compl) + L = inner(f, v) * dx1 + inner((u('+') + u('-')) / 2., dot(tau, nsub)) * ds1(boun_int) + inner(u_exact, dot(tau, nsub)) * ds0(boun_ext) + solve(a - L == 0, w, bcs=[bc]) + sigma_error = sqrt(assemble(inner(sigma - sigma_exact, sigma - sigma_exact) * dx1)) + u_error = sqrt(assemble(inner(u - u_exact, u - u_exact) * dx0(label_submesh))) + return sigma_error, u_error + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('hexahedral', [False]) +@pytest.mark.parametrize('degree', [4]) +@pytest.mark.parametrize('submesh_region', ["left", "right", "front", "back", "bottom", "top"]) +def test_submesh_solve_mixed_poisson_check_sanity_3d(hexahedral, degree, submesh_region): + sigma_error, u_error = _mixed_poisson_solve_3d(hexahedral, degree, submesh_region) + assert sigma_error < 0.07 + assert u_error < 0.003 + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize('simplex', [True, False]) +@pytest.mark.parametrize('nref', [1, 3]) +@pytest.mark.parametrize('degree', [2, 4]) +def test_submesh_solve_cell_cell_equation_bc(nref, degree, simplex): + dim = 2 + mesh = RectangleMesh(3 ** nref, 2 ** nref, 3., 2., quadrilateral=not simplex) + x, y = SpatialCoordinate(mesh) + label_outer = 101 + label_inner = 100 + label_interface = 5 # automatically labeled by Submesh + DG0 = FunctionSpace(mesh, "DG", 0) + f_outer = Function(DG0).interpolate(conditional(Or(Or(x < 1., x > 2.), y > 1.), 1, 0)) + f_inner = Function(DG0).interpolate(conditional(And(And(x > 1., x < 2.), y < 1.), 1, 0)) + mesh = RelabeledMesh(mesh, [f_outer, f_inner], [label_outer, label_inner]) + x, y = SpatialCoordinate(mesh) + mesh_outer = Submesh(mesh, dim, label_outer) + x_outer, y_outer = SpatialCoordinate(mesh_outer) + mesh_inner = Submesh(mesh, dim, label_inner) + x_inner, y_inner = SpatialCoordinate(mesh_inner) + V_outer = FunctionSpace(mesh_outer, "CG", degree) + V_inner = FunctionSpace(mesh_inner, "CG", degree) + V = V_outer * V_inner + u = TrialFunction(V) + v = TestFunction(V) + sol = Function(V) + u_outer, u_inner = split(u) + v_outer, v_inner = split(v) + dx_outer = Measure("dx", domain=mesh_outer, intersect_measures=(Measure("dx", mesh), Measure("dx", mesh_inner))) + dx_inner = Measure("dx", domain=mesh_inner, intersect_measures=(Measure("dx", mesh), Measure("dx", mesh_outer))) + ds_outer = Measure("ds", domain=mesh_outer, intersect_measures=(Measure("ds", mesh_inner),)) + a = inner(grad(u_outer), grad(v_outer)) * dx_outer + \ + inner(u_inner, v_inner) * dx_inner + L = inner(x * y, v_inner) * dx_inner + dbc = DirichletBC(V.sub(0), x_outer * y_outer, (1, 2, 3, 4)) + ebc = EquationBC(inner(u_outer - u_inner, v_outer) * ds_outer(label_interface) == inner(Constant(0.), v_outer) * ds_outer(label_interface), sol, label_interface, V=V.sub(0)) + solve(a == L, sol, bcs=[dbc, ebc]) + assert sqrt(assemble(inner(sol[0] - x * y, sol[0] - x * y) * dx_outer)) < 1.e-12 + assert sqrt(assemble(inner(sol[1] - x * y, sol[1] - x * y) * dx_inner)) < 1.e-12 diff --git a/tests/tsfc/test_tsfc_182.py b/tests/tsfc/test_tsfc_182.py index 556a6bafb0..a6208a491c 100644 --- a/tests/tsfc/test_tsfc_182.py +++ b/tests/tsfc/test_tsfc_182.py @@ -1,6 +1,6 @@ import pytest -from ufl import Coefficient, TestFunction, dx, inner, tetrahedron, Mesh, FunctionSpace +from ufl import Coefficient, TestFunction, dx, inner, tetrahedron, Mesh, MeshSequence, FunctionSpace from finat.ufl import FiniteElement, MixedElement, VectorElement from tsfc import compile_form @@ -20,7 +20,8 @@ def test_delta_elimination(mode): element_chi_lambda = MixedElement(element_eps_p, element_lambda) domain = Mesh(VectorElement("Lagrange", tetrahedron, 1)) - space = FunctionSpace(domain, element_chi_lambda) + domains = MeshSequence([domain, domain]) + space = FunctionSpace(domains, element_chi_lambda) chi_lambda = Coefficient(space) delta_chi_lambda = TestFunction(space) diff --git a/tests/tsfc/test_tsfc_204.py b/tests/tsfc/test_tsfc_204.py index 89f1481590..fe889e48b8 100644 --- a/tests/tsfc/test_tsfc_204.py +++ b/tests/tsfc/test_tsfc_204.py @@ -1,12 +1,13 @@ from tsfc import compile_form from ufl import (Coefficient, FacetNormal, - FunctionSpace, Mesh, as_matrix, + FunctionSpace, Mesh, MeshSequence, as_matrix, dot, dS, ds, dx, facet, grad, inner, outer, split, triangle) from finat.ufl import BrokenElement, FiniteElement, MixedElement, VectorElement def test_physically_mapped_facet(): mesh = Mesh(VectorElement("P", triangle, 1)) + meshes = MeshSequence([mesh, mesh, mesh, mesh, mesh]) # set up variational problem U = FiniteElement("Morley", mesh.ufl_cell(), 2) @@ -15,7 +16,7 @@ def test_physically_mapped_facet(): Vv = VectorElement(BrokenElement(V)) Qhat = VectorElement(BrokenElement(V[facet]), dim=2) Vhat = VectorElement(V[facet], dim=2) - Z = FunctionSpace(mesh, MixedElement(U, Vv, Qhat, Vhat, R)) + Z = FunctionSpace(meshes, MixedElement(U, Vv, Qhat, Vhat, R)) z = Coefficient(Z) u, d, qhat, dhat, lam = split(z) diff --git a/tsfc/driver.py b/tsfc/driver.py index 89db890f24..eda810bdf7 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -10,7 +10,7 @@ from ufl.algorithms.analysis import has_type from ufl.algorithms.apply_coefficient_split import CoefficientSplitter from ufl.classes import Form, GeometricQuantity -from ufl.domain import extract_unique_domain +from ufl.domain import extract_unique_domain, extract_domains import gem import gem.impero_utils as impero_utils @@ -28,7 +28,7 @@ TSFCIntegralDataInfo = collections.namedtuple("TSFCIntegralDataInfo", - ["domain", "integral_type", "subdomain_id", "domain_number", + ["domain", "integral_type", "subdomain_id", "domain_number", "domain_integral_type_map", "arguments", "coefficients", "coefficient_split", "coefficient_numbers"]) TSFCIntegralDataInfo.__doc__ = """ @@ -93,9 +93,10 @@ def compile_form(form, prefix="form", parameters=None, dont_split_numbers=(), di kernels = [] for integral_data in form_data.integral_data: start = time.time() - kernel = compile_integral(integral_data, form_data, prefix, parameters, diagonal=diagonal) - if kernel is not None: - kernels.append(kernel) + if integral_data.integrals: + kernel = compile_integral(integral_data, form_data, prefix, parameters, diagonal=diagonal) + if kernel is not None: + kernels.append(kernel) logger.info(GREEN % "compile_integral finished in %g seconds.", time.time() - start) logger.info(GREEN % "TSFC finished in %g seconds.", time.time() - cpu_time) @@ -115,14 +116,10 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F parameters = preprocess_parameters(parameters) scalar_type = parameters["scalar_type"] integral_type = integral_data.integral_type - mesh = integral_data.domain arguments = form_data.preprocessed_form.arguments() if integral_type.startswith("interior_facet") and diagonal and any(a.function_space().finat_element.is_dg() for a in arguments): raise NotImplementedError("Sorry, we can't assemble the diagonal of a form for interior facet integrals") kernel_name = f"{prefix}_{integral_type}_integral" - # Dict mapping domains to index in original_form.ufl_domains() - domain_numbering = form_data.original_form.domain_numbering() - domain_number = domain_numbering[integral_data.domain] # This is which coefficient in the original form the # current coefficient is. # Consider f*v*dx + g*v*ds, the full form contains two @@ -137,11 +134,15 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F if coeff in form_data.coefficient_split: coefficient_split[coeff] = form_data.coefficient_split[coeff] coefficient_numbers.append(form_data.original_coefficient_positions[i]) + mesh = integral_data.domain + all_meshes = extract_domains(form_data.original_form) + domain_number = all_meshes.index(mesh) integral_data_info = TSFCIntegralDataInfo( domain=integral_data.domain, integral_type=integral_data.integral_type, subdomain_id=integral_data.subdomain_id, domain_number=domain_number, + domain_integral_type_map={mesh: integral_data.domain_integral_type_map[mesh] if mesh in integral_data.domain_integral_type_map else None for mesh in all_meshes}, arguments=arguments, coefficients=coefficients, coefficient_split=coefficient_split, @@ -152,8 +153,11 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F scalar_type, diagonal=diagonal, ) - builder.set_coordinates(mesh) - builder.set_cell_sizes(mesh) + builder.set_entity_numbers(all_meshes) + builder.set_entity_orientations(all_meshes) + builder.set_coordinates(all_meshes) + builder.set_cell_orientations(all_meshes) + builder.set_cell_sizes(all_meshes) builder.set_coefficients() # TODO: We do not want pass constants to kernels that do not need them # so we should attach the constants to integral data instead @@ -242,11 +246,16 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if domain is None: domain = extract_unique_domain(expression) assert domain is not None + builder._domain_integral_type_map = {domain: "cell"} # Collect required coefficients and determine numbering coefficients = extract_coefficients(expression) coefficient_numbers = tuple(map(orig_coefficients.index, coefficients)) builder.set_coefficient_numbers(coefficient_numbers) + # Need this ad-hoc fix for now. + for c in coefficients: + d = extract_unique_domain(c) + builder._domain_integral_type_map[d] = "cell" elements = [f.ufl_element() for f in (*coefficients, *arguments)] @@ -255,7 +264,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Create a fake coordinate coefficient for a domain. coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element())) builder.domain_coordinate[domain] = coords_coefficient - builder.set_cell_sizes(domain) + builder.set_cell_orientations((domain, )) + builder.set_cell_sizes((domain, )) coefficients = [coords_coefficient] + coefficients needs_external_coords = True builder.set_coefficients(coefficients) @@ -272,7 +282,6 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, ufl_cell=domain.ufl_cell(), # FIXME: change if we ever implement # interpolation on facets. - integral_type="cell", argument_multiindices=argument_multiindices, index_cache={}, scalar_type=parameters["scalar_type"]) diff --git a/tsfc/fem.py b/tsfc/fem.py index 9166b4b8f0..0399c5a421 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -31,6 +31,7 @@ from ufl.corealg.map_dag import map_expr_dag, map_expr_dags from ufl.corealg.multifunction import MultiFunction from ufl.domain import extract_unique_domain +from ufl.algorithms import extract_arguments from tsfc import ufl2gem from tsfc.kernel_interface import ProxyKernelInterface @@ -50,7 +51,6 @@ class ContextBase(ProxyKernelInterface): keywords = ( 'ufl_cell', 'fiat_cell', - 'integral_type', 'integration_dim', 'entity_ids', 'argument_multiindices', @@ -86,7 +86,7 @@ def epsilon(self): def complex_mode(self): return is_complex(self.scalar_type) - def entity_selector(self, callback, restriction): + def entity_selector(self, callback, domain, restriction): """Selects code for the correct entity at run-time. Callback generates code for a specified entity. @@ -100,7 +100,7 @@ def entity_selector(self, callback, restriction): if len(self.entity_ids) == 1: return callback(self.entity_ids[0]) else: - f = self.entity_number(restriction) + f = self.entity_number(domain, restriction) return gem.select_expression(list(map(callback, self.entity_ids)), f) argument_multiindices = () @@ -119,7 +119,19 @@ def use_canonical_quadrature_point_ordering(self): # Directly set use_canonical_quadrature_point_ordering = False in context # for translation of special nodes, e.g., CellVolume, FacetArea, CellOrigin, and CellVertices, # as quadrature point ordering is not relevant for those node types. - return isinstance(self.fiat_cell, UFCHexahedron) and self.integral_type in ['exterior_facet', 'interior_facet'] + cell_integral_type_map = { + as_fiat_cell(domain.ufl_cell()): integral_type + for domain, integral_type in self.domain_integral_type_map.items() + if integral_type is not None + } + if all(integral_type == 'cell' for integral_type in cell_integral_type_map.values()): + return False + elif all(integral_type in ['exterior_facet', 'interior_facet'] for integral_type in cell_integral_type_map.values()): + if all(isinstance(cell, UFCHexahedron) for cell in cell_integral_type_map): + return True + elif len(set(cell_integral_type_map)) > 1: # mixed cell types + return True + return False class CoordinateMapping(PhysicalGeometry): @@ -144,7 +156,8 @@ def preprocess(self, expr, context): :arg context: The translation context. :returns: A new UFL expression """ - ifacet = self.interface.integral_type.startswith("interior_facet") + domain = extract_unique_domain(self.mt.terminal) + ifacet = self.interface.domain_integral_type_map[domain].startswith("interior_facet") return preprocess_expression(expr, complex_mode=context.complex_mode, do_apply_restrictions=ifacet) @@ -173,7 +186,7 @@ def translate_point_expression(self, expr, point=None): return map_expr_dag(context.translator, expr) def cell_size(self): - return self.interface.cell_size(self.mt.restriction) + return self.interface.cell_size(extract_unique_domain(self.mt.terminal), self.mt.restriction) def jacobian_at(self, point): expr = Jacobian(extract_unique_domain(self.mt.terminal)) @@ -275,7 +288,9 @@ def make_basis_evaluation_key(ctx, finat_element, mt, entity_id): ufl_element = mt.terminal.ufl_element() domain = extract_unique_domain(mt.terminal) coordinate_element = domain.ufl_coordinate_element() - return (ufl_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction) + # This way of caching is fragile. + # Should Implement _hash_key_() for ModifiedTerminal and use the entire mt as key. + return (ufl_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_()) class PointSetContext(ContextBase): @@ -351,14 +366,15 @@ def __init__(self, context): # Can't put these in the ufl2gem mixin, since they (unlike # everything else) want access to the translation context. def cell_avg(self, o): - if self.context.integral_type != "cell": + domain = extract_unique_domain(o) + integral_type = self.context.domain_integral_type_map[domain] + if integral_type != "cell": # Need to create a cell-based quadrature rule and # translate the expression using that (c.f. CellVolume # below). raise NotImplementedError("CellAvg on non-cell integrals not yet implemented") integrand, = o.ufl_operands - domain = extract_unique_domain(o) - measure = ufl.Measure(self.context.integral_type, domain=domain) + measure = ufl.Measure(integral_type, domain=domain) integrand, degree, argument_multiindices = entity_avg(integrand / CellVolume(domain), measure, self.context.argument_multiindices) config = {name: getattr(self.context, name) @@ -369,17 +385,17 @@ def cell_avg(self, o): return expr def facet_avg(self, o): - if self.context.integral_type == "cell": + domain = extract_unique_domain(o) + integral_type = self.context.domain_integral_type_map[domain] + if integral_type == "cell": raise ValueError("Can't take FacetAvg in cell integral") integrand, = o.ufl_operands - domain = extract_unique_domain(o) - measure = ufl.Measure(self.context.integral_type, domain=domain) + measure = ufl.Measure(integral_type, domain=domain) integrand, degree, argument_multiindices = entity_avg(integrand / FacetArea(domain), measure, self.context.argument_multiindices) config = {name: getattr(self.context, name) for name in ["ufl_cell", "index_cache", "scalar_type", - "integration_dim", "entity_ids", - "integral_type"]} + "integration_dim", "entity_ids"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) @@ -416,7 +432,7 @@ def translate_geometricquantity(terminal, mt, ctx): @translate.register(CellOrientation) def translate_cell_orientation(terminal, mt, ctx): - return ctx.cell_orientation(mt.restriction) + return ctx.cell_orientation(extract_unique_domain(terminal), mt.restriction) @translate.register(ReferenceCellVolume) @@ -426,7 +442,7 @@ def translate_reference_cell_volume(terminal, mt, ctx): @translate.register(ReferenceFacetVolume) def translate_reference_facet_volume(terminal, mt, ctx): - assert ctx.integral_type != "cell" + assert ctx.domain_integral_type_map[extract_unique_domain(terminal)] != "cell" # Sum of quadrature weights is entity volume return gem.optimise.aggressive_unroll(gem.index_sum(ctx.weight_expr, ctx.point_indices)) @@ -440,7 +456,7 @@ def translate_cell_facet_jacobian(terminal, mt, ctx): def callback(entity_id): return gem.Literal(make_cell_facet_jacobian(cell, facet_dim, entity_id)) - return ctx.entity_selector(callback, mt.restriction) + return ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction) def make_cell_facet_jacobian(cell, facet_dim, facet_i): @@ -465,7 +481,7 @@ def translate_reference_normal(terminal, mt, ctx): def callback(facet_i): n = ctx.fiat_cell.compute_reference_normal(ctx.integration_dim, facet_i) return gem.Literal(n) - return ctx.entity_selector(callback, mt.restriction) + return ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction) @translate.register(ReferenceCellEdgeVectors) @@ -498,7 +514,7 @@ def callback(entity_id): data = numpy.asarray(list(map(t, ps.points))) return gem.Literal(data.reshape(point_shape + data.shape[1:])) - return gem.partial_indexed(ctx.entity_selector(callback, mt.restriction), + return gem.partial_indexed(ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction), ps.indices) @@ -549,9 +565,10 @@ def translate_cellvolume(terminal, mt, ctx): @translate.register(FacetArea) def translate_facetarea(terminal, mt, ctx): - assert ctx.integral_type != 'cell' domain = extract_unique_domain(terminal) - integrand, degree = one_times(ufl.Measure(ctx.integral_type, domain=domain)) + integral_type = ctx.domain_integral_type_map[domain] + assert integral_type != 'cell' + integrand, degree = one_times(ufl.Measure(integral_type, domain=domain)) config = {name: getattr(ctx, name) for name in ["ufl_cell", "integration_dim", "scalar_type", @@ -649,10 +666,10 @@ def callback(entity_id): # A numerical hack that FFC used to apply on FIAT tables still # lives on after ditching FFC and switching to FInAT. return ffc_rounding(square, ctx.epsilon) - table = ctx.entity_selector(callback, mt.restriction) + table = ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction) if ctx.use_canonical_quadrature_point_ordering: quad_multiindex = ctx.quadrature_rule.point_set.indices - quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) + quad_multiindex_permuted = _make_quad_multiindex_permuted(terminal, mt, ctx) mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) table = mapper(table, tuple(zip(quad_multiindex, quad_multiindex_permuted))) argument_multiindex = ctx.argument_multiindices[terminal.number()] @@ -695,7 +712,7 @@ def take_singleton(xs): per_derivative = {alpha: take_singleton(tables) for alpha, tables in per_derivative.items()} else: - f = ctx.entity_number(mt.restriction) + f = ctx.entity_number(extract_unique_domain(terminal), mt.restriction) per_derivative = {alpha: gem.select_expression(tables, f) for alpha, tables in per_derivative.items()} @@ -727,13 +744,13 @@ def take_singleton(xs): if ctx.use_canonical_quadrature_point_ordering: quad_multiindex = ctx.quadrature_rule.point_set.indices - quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) + quad_multiindex_permuted = _make_quad_multiindex_permuted(terminal, mt, ctx) mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) result = mapper(result, tuple(zip(quad_multiindex, quad_multiindex_permuted))) return result -def _make_quad_multiindex_permuted(mt, ctx): +def _make_quad_multiindex_permuted(terminal, mt, ctx): quad_rule = ctx.quadrature_rule # Note that each quad index here represents quad points on a physical # cell axis, but the table is indexed by indices representing the points @@ -746,7 +763,8 @@ def _make_quad_multiindex_permuted(mt, ctx): if len(extents) != 1: raise ValueError("Must have the same number of quadrature points in each symmetric axis") quad_multiindex_permuted = [] - o = ctx.entity_orientation(mt.restriction) + domain = extract_unique_domain(terminal) + o = ctx.entity_orientation(domain, mt.restriction) if not isinstance(o, FIATOrientation): raise ValueError(f"Expecting an instance of FIATOrientation : got {o}") eo = cell.extract_extrinsic_orientation(o) @@ -761,27 +779,23 @@ def _make_quad_multiindex_permuted(mt, ctx): return tuple(quad_multiindex_permuted) -def compile_ufl(expression, context, interior_facet=False, point_sum=False): +def compile_ufl(expression, context, point_sum=False): """Translate a UFL expression to GEM. :arg expression: The UFL expression to compile. :arg context: translation context - either a :class:`GemPointContext` or :class:`PointSetContext` - :arg interior_facet: If ``true``, treat expression as an interior - facet integral (default ``False``) :arg point_sum: If ``true``, return a `gem.IndexSum` of the final gem expression along the ``context.point_indices`` (if present). """ # Abs-simplification expression = simplify_abs(expression, context.complex_mode) - if interior_facet: - expressions = [] - for rs in itertools.product(("+", "-"), repeat=len(context.argument_multiindices)): - expressions.append(map_expr_dag(PickRestriction(*rs), expression)) - else: - expressions = [expression] - + arguments = extract_arguments(expression) + domains = [extract_unique_domain(argument) for argument in arguments] + integral_types = [context.domain_integral_type_map[domain] for domain in domains] + rs_tuples = [("+", "-") if integral_type.startswith("interior_facet") else (None, ) for integral_type in integral_types] + expressions = [map_expr_dag(PickRestriction(*rs), expression) for rs in itertools.product(*rs_tuples)] # Translate UFL to GEM, lowering finite element specific nodes result = map_expr_dags(context.translator, expressions) if point_sum: diff --git a/tsfc/kernel_args.py b/tsfc/kernel_args.py index a397f0f937..80b1bed77f 100644 --- a/tsfc/kernel_args.py +++ b/tsfc/kernel_args.py @@ -54,9 +54,9 @@ class InteriorFacetKernelArg(KernelArg): ... -class ExteriorFacetOrientationKernelArg(KernelArg): +class OrientationsExteriorFacetKernelArg(KernelArg): ... -class InteriorFacetOrientationKernelArg(KernelArg): +class OrientationsInteriorFacetKernelArg(KernelArg): ... diff --git a/tsfc/kernel_interface/__init__.py b/tsfc/kernel_interface/__init__.py index 5114263848..9b7419e35d 100644 --- a/tsfc/kernel_interface/__init__.py +++ b/tsfc/kernel_interface/__init__.py @@ -22,19 +22,19 @@ def constant(self, const): """Return the GEM expression corresponding to the constant.""" @abstractmethod - def cell_orientation(self, restriction): + def cell_orientation(self, domain, restriction): """Cell orientation as a GEM expression.""" @abstractmethod - def cell_size(self, restriction): + def cell_size(self, domain, restriction): """Mesh cell size as a GEM expression. Shape (nvertex, ) in FIAT vertex ordering.""" @abstractmethod - def entity_number(self, restriction): + def entity_number(self, domain, restriction): """Facet or vertex number as a GEM index.""" @abstractmethod - def entity_orientation(self, restriction): + def entity_orientation(self, domain, restriction): """Entity orientation as a GEM index.""" @abstractmethod @@ -47,5 +47,9 @@ def unsummed_coefficient_indices(self): """A set of indices that coefficient evaluation should not sum over. Used for macro-cell integration.""" + @abstractproperty + def domain_integral_type_map(self): + """domain integral_type map.""" + ProxyKernelInterface = make_proxy_class('ProxyKernelInterface', KernelInterface) diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index 53d3d96afa..d1bf7653db 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -3,6 +3,10 @@ import string from functools import cached_property, reduce from itertools import chain, product +import copy + +from ufl.utils.sequences import max_degree +from ufl.domain import extract_unique_domain import gem import gem.impero_utils as impero_utils @@ -20,20 +24,14 @@ from finat.ufl import MixedElement from tsfc.kernel_interface import KernelInterface from tsfc.logging import logger -from ufl.utils.sequences import max_degree class KernelBuilderBase(KernelInterface): """Helper class for building local assembly kernels.""" - def __init__(self, scalar_type, interior_facet=False): - """Initialise a kernel builder. - - :arg interior_facet: kernel accesses two cells - """ - assert isinstance(interior_facet, bool) + def __init__(self, scalar_type): + """Initialise a kernel builder.""" self.scalar_type = scalar_type - self.interior_facet = interior_facet self.prepare = [] self.finalise = [] @@ -58,9 +56,10 @@ def coefficient(self, ufl_coefficient, restriction): """A function that maps :class:`ufl.Coefficient`s to GEM expressions.""" kernel_arg = self.coefficient_map[ufl_coefficient] + domain = extract_unique_domain(ufl_coefficient) if ufl_coefficient.ufl_element().family() == 'Real': return kernel_arg - elif not self.interior_facet: + elif not self._domain_integral_type_map[domain].startswith("interior_facet"): return kernel_arg else: return kernel_arg[{'+': 0, '-': 1}[restriction]] @@ -68,34 +67,37 @@ def coefficient(self, ufl_coefficient, restriction): def constant(self, const): return self.constant_map[const] - def cell_orientation(self, restriction): + def cell_orientation(self, domain, restriction): """Cell orientation as a GEM expression.""" + if not hasattr(self, "_cell_orientations"): + raise RuntimeError("Haven't called set_cell_orientations") f = {None: 0, '+': 0, '-': 1}[restriction] - # Assume self._cell_orientations tuple is set up at this point. - co_int = self._cell_orientations[f] + co_int = self._cell_orientations[domain][f] return gem.Conditional(gem.Comparison("==", co_int, gem.Literal(1)), gem.Literal(-1), gem.Conditional(gem.Comparison("==", co_int, gem.Zero()), gem.Literal(1), gem.Literal(numpy.nan))) - def cell_size(self, restriction): + def cell_size(self, domain, restriction): if not hasattr(self, "_cell_sizes"): raise RuntimeError("Haven't called set_cell_sizes") - if self.interior_facet: - return self._cell_sizes[{'+': 0, '-': 1}[restriction]] + if self._domain_integral_type_map[domain].startswith("interior_facet"): + return self._cell_sizes[domain][{'+': 0, '-': 1}[restriction]] else: - return self._cell_sizes + return self._cell_sizes[domain] - def entity_number(self, restriction): + def entity_number(self, domain, restriction): """Facet or vertex number as a GEM index.""" - # Assume self._entity_number dict is set up at this point. - return self._entity_number[restriction] + if not hasattr(self, "_entity_numbers"): + raise RuntimeError("Haven't called set_entity_numbers") + return self._entity_numbers[domain][restriction] - def entity_orientation(self, restriction): + def entity_orientation(self, domain, restriction): """Facet orientation as a GEM index.""" - # Assume self._entity_orientation dict is set up at this point. - return self._entity_orientation[restriction] + if not hasattr(self, "_entity_orientations"): + raise RuntimeError("Haven't called set_entity_orientations") + return self._entity_orientations[domain][restriction] def apply_glue(self, prepare=None, finalise=None): """Append glue code for operations that are not handled in the @@ -120,6 +122,11 @@ def register_requirements(self, ir): # Nothing is required by default pass + @property + def domain_integral_type_map(self): + """domain integral_type map.""" + return self._domain_integral_type_map + class KernelBuilderMixin(object): """Mixin for KernelBuilder classes.""" @@ -143,8 +150,7 @@ def compile_integrand(self, integrand, params, ctx): config['quadrature_rule'] = quad_rule config['index_cache'] = ctx['index_cache'] expressions = fem.compile_ufl(integrand, - fem.PointSetContext(**config), - interior_facet=self.interior_facet) + fem.PointSetContext(**config)) ctx['quadrature_indices'].extend(quad_rule.point_set.indices) return expressions @@ -214,7 +220,7 @@ def compile_gem(self, ctx): # Let the kernel interface inspect the optimised IR to register # what kind of external data is required (e.g., cell orientations, # cell sizes, etc.). - oriented, needs_cell_sizes, tabulations, need_facet_orientation = self.register_requirements(expressions) + oriented, needs_cell_sizes, tabulations = self.register_requirements(expressions) # Extract Variables that are actually used active_variables = gem.extract_type(expressions, gem.Variable) @@ -225,7 +231,7 @@ def compile_gem(self, ctx): impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True) except impero_utils.NoopError: impero_c = None - return impero_c, oriented, needs_cell_sizes, tabulations, active_variables, need_facet_orientation + return impero_c, oriented, needs_cell_sizes, tabulations, active_variables def fem_config(self): """Return a dictionary used with fem.compile_ufl. @@ -241,7 +247,6 @@ def fem_config(self): integration_dim, entity_ids = lower_integral_type(fiat_cell, integral_type) return dict(interface=self, ufl_cell=cell, - integral_type=integral_type, integration_dim=integration_dim, entity_ids=entity_ids, scalar_type=self.fem_scalar_type) @@ -439,19 +444,16 @@ def check_requirements(ir): in one pass.""" cell_orientations = False cell_sizes = False - facet_orientation = False rt_tabs = {} for node in traversal(ir): if isinstance(node, gem.Variable): - if node.name == "cell_orientations": + if node.name == "cell_orientations_0": cell_orientations = True - elif node.name == "cell_sizes": + elif node.name == "cell_sizes_0": cell_sizes = True elif node.name.startswith("rt_"): rt_tabs[node.name] = node.shape - elif node.name == "facet_orientation": - facet_orientation = True - return cell_orientations, cell_sizes, tuple(sorted(rt_tabs.items())), facet_orientation + return cell_orientations, cell_sizes, tuple(sorted(rt_tabs.items())) def prepare_constant(constant, number): @@ -468,55 +470,74 @@ def prepare_constant(constant, number): constant.ufl_shape) -def prepare_coefficient(coefficient, name, interior_facet=False): +def prepare_coefficient(coefficient, name, domain_integral_type_map): """Bridges the kernel interface and the GEM abstraction for Coefficients. - :arg coefficient: UFL Coefficient - :arg name: unique name to refer to the Coefficient in the kernel - :arg interior_facet: interior facet integral? - :returns: (funarg, expression) - expression - GEM expression referring to the Coefficient - values - """ - assert isinstance(interior_facet, bool) + Parameters + ---------- + coefficient : ufl.Coefficient + UFL Coefficient. + name : str + Unique name to refer to the Coefficient in the kernel. + domain_integral_type_map : dict + Map from domain to integral_type. + + Returns + ------- + gem.Node + GEM expression referring to the Coefficient values. + """ if coefficient.ufl_element().family() == 'Real': # Constant value_size = coefficient.ufl_function_space().value_size expression = gem.reshape(gem.Variable(name, (value_size,)), coefficient.ufl_shape) return expression - finat_element = create_element(coefficient.ufl_element()) shape = finat_element.index_shape size = numpy.prod(shape, dtype=int) - - if not interior_facet: - expression = gem.reshape(gem.Variable(name, (size,)), shape) - else: + domain = extract_unique_domain(coefficient) + integral_type = domain_integral_type_map[domain] + if integral_type is None: + # This means that this coefficient does not exist in the DAG, + # so corresponding gem expression will never be needed. + expression = None + elif integral_type.startswith("interior_facet"): varexp = gem.Variable(name, (2 * size,)) plus = gem.view(varexp, slice(size)) minus = gem.view(varexp, slice(size, 2 * size)) expression = (gem.reshape(plus, shape), gem.reshape(minus, shape)) + else: + expression = gem.reshape(gem.Variable(name, (size,)), shape) return expression -def prepare_arguments(arguments, multiindices, interior_facet=False, diagonal=False): +def prepare_arguments(arguments, multiindices, domain_integral_type_map, diagonal=False): """Bridges the kernel interface and the GEM abstraction for Arguments. Vector Arguments are rearranged here for interior facet integrals. - :arg arguments: UFL Arguments - :arg multiindices: Argument multiindices - :arg interior_facet: interior facet integral? - :arg diagonal: Are we assembling the diagonal of a rank-2 element tensor? - :returns: (funarg, expression) - expressions - GEM expressions referring to the argument - tensor - """ - assert isinstance(interior_facet, bool) + Parameters + ---------- + arguments : tuple + UFL Arguments. + multiindices : tuple + Argument multiindices. + domain_integral_type_map : dict + Map from domain to integral_type. + diagonal : bool + Are we assembling the diagonal of a rank-2 element tensor? + + Returns + ------- + tuple + Tuple of function arg and GEM expressions referring to the argument tensor. + """ + if len(multiindices) != len(arguments): + raise ValueError(f"Got inconsistent lengths of arguments ({len(arguments)}) and multiindices ({len(multiindices)})") if len(arguments) == 0: # No arguments expression = gem.Indexed(gem.Variable("A", (1,)), (0,)) @@ -532,25 +553,30 @@ def prepare_arguments(arguments, multiindices, interior_facet=False, diagonal=Fa element, = set(elements) except ValueError: raise ValueError("Diagonal only for diagonal blocks (test and trial spaces the same)") - elements = (element, ) shapes = tuple(element.index_shape for element in elements) multiindices = multiindices[:1] + arguments = arguments[:1] def expression(restricted): return gem.Indexed(gem.reshape(restricted, *shapes), tuple(chain(*multiindices))) u_shape = numpy.array([numpy.prod(shape, dtype=int) for shape in shapes]) - if interior_facet: - c_shape = tuple(2 * u_shape) - slicez = [[slice(r * s, (r + 1) * s) - for r, s in zip(restrictions, u_shape)] - for restrictions in product((0, 1), repeat=len(arguments))] - else: - c_shape = tuple(u_shape) - slicez = [[slice(s) for s in u_shape]] - - varexp = gem.Variable("A", c_shape) + c_shape = copy.deepcopy(u_shape) + rs_tuples = [] + for arg_num, arg in enumerate(arguments): + integral_type = domain_integral_type_map[extract_unique_domain(arg)] + if integral_type is None: + raise RuntimeError(f"Can not determine integral_type on {arg}") + if integral_type.startswith("interior_facet"): + rs_tuples.append((0, 1)) + c_shape[arg_num] *= 2 + else: + rs_tuples.append((0, )) + slicez = [[slice(r * s, (r + 1) * s) + for r, s in zip(restrictions, u_shape)] + for restrictions in product(*rs_tuples)] + varexp = gem.Variable("A", tuple(c_shape)) expressions = [expression(gem.view(varexp, *slices)) for slices in slicez] return tuple(prune(expressions)) diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index f13a7d1e33..856b31761d 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -2,6 +2,8 @@ from collections import namedtuple, OrderedDict from ufl import Coefficient, FunctionSpace +from ufl.domain import MeshSequence + from finat.ufl import MixedElement as ufl_MixedElement, FiniteElement import gem @@ -23,16 +25,28 @@ 'flop_count', 'event']) +ActiveDomainNumbers = namedtuple('ActiveDomainNumbers', ['coordinates', + 'cell_orientations', + 'cell_sizes', + 'exterior_facets', + 'interior_facets', + 'orientations_exterior_facet', + 'orientations_interior_facet']) +ActiveDomainNumbers.__doc__ = """ + Active domain numbers collected for each key. + + """ + + class Kernel: - __slots__ = ("ast", "arguments", "integral_type", "oriented", "subdomain_id", - "domain_number", "needs_cell_sizes", "tabulations", + __slots__ = ("ast", "arguments", "integral_type", "subdomain_id", + "domain_number", "active_domain_numbers", "tabulations", "coefficient_numbers", "name", "flop_count", "event", "__weakref__") """A compiled Kernel object. :kwarg ast: The loopy kernel object. :kwarg integral_type: The type of integral. - :kwarg oriented: Does the kernel require cell_orientations. :kwarg subdomain_id: What is the subdomain id for this kernel. :kwarg domain_number: Which domain number in the original form does this kernel correspond to (can be used to index into @@ -40,15 +54,13 @@ class Kernel: :kwarg coefficient_numbers: A list of which coefficients from the form the kernel needs. :kwarg tabulations: The runtime tabulations this kernel requires - :kwarg needs_cell_sizes: Does the kernel require cell sizes. :kwarg name: The name of this kernel. :kwarg flop_count: Estimated total flops for this kernel. :kwarg event: name for logging event """ - def __init__(self, ast=None, arguments=None, integral_type=None, oriented=False, - subdomain_id=None, domain_number=None, + def __init__(self, ast=None, arguments=None, integral_type=None, + subdomain_id=None, domain_number=None, active_domain_numbers=None, coefficient_numbers=(), - needs_cell_sizes=False, tabulations=None, flop_count=0, name=None, @@ -57,11 +69,10 @@ def __init__(self, ast=None, arguments=None, integral_type=None, oriented=False, self.ast = ast self.arguments = arguments self.integral_type = integral_type - self.oriented = oriented self.domain_number = domain_number + self.active_domain_numbers = active_domain_numbers self.subdomain_id = subdomain_id self.coefficient_numbers = coefficient_numbers - self.needs_cell_sizes = needs_cell_sizes self.tabulations = tabulations self.flop_count = flop_count self.name = name @@ -70,21 +81,9 @@ def __init__(self, ast=None, arguments=None, integral_type=None, oriented=False, class KernelBuilderBase(_KernelBuilderBase): - def __init__(self, scalar_type, interior_facet=False): - """Initialise a kernel builder. - - :arg interior_facet: kernel accesses two cells - """ - super().__init__(scalar_type=scalar_type, interior_facet=interior_facet) - - # Cell orientation - if self.interior_facet: - cell_orientations = gem.Variable("cell_orientations", (2,), dtype=gem.uint_type) - self._cell_orientations = (gem.Indexed(cell_orientations, (0,)), - gem.Indexed(cell_orientations, (1,))) - else: - cell_orientations = gem.Variable("cell_orientations", (1,), dtype=gem.uint_type) - self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),) + def __init__(self, scalar_type): + """Initialise a kernel builder.""" + super().__init__(scalar_type=scalar_type) def _coefficient(self, coefficient, name): """Prepare a coefficient. Adds glue code for the coefficient @@ -94,24 +93,58 @@ def _coefficient(self, coefficient, name): :arg name: coefficient name :returns: GEM expression representing the coefficient """ - expr = prepare_coefficient(coefficient, name, interior_facet=self.interior_facet) + expr = prepare_coefficient(coefficient, name, self._domain_integral_type_map) self.coefficient_map[coefficient] = expr return expr - def set_coordinates(self, domain): - """Prepare the coordinate field. + def set_coordinates(self, domains): + """Set coordinates for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. - :arg domain: :class:`ufl.Domain` """ # Create a fake coordinate coefficient for a domain. - f = Coefficient(FunctionSpace(domain, domain.ufl_coordinate_element())) - self.domain_coordinate[domain] = f - self._coefficient(f, "coords") + for i, domain in enumerate(domains): + if isinstance(domain, MeshSequence): + raise RuntimeError("Found a MeshSequence") + f = Coefficient(FunctionSpace(domain, domain.ufl_coordinate_element())) + self.domain_coordinate[domain] = f + self._coefficient(f, f"coords_{i}") + + def set_cell_orientations(self, domains): + """Set cell orientations for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. + + """ + # Cell orientation + self._cell_orientations = {} + for i, domain in enumerate(domains): + integral_type = self._domain_integral_type_map[domain] + if integral_type is None: + # See comment in prepare_coefficient. + self._cell_orientations[domain] = None + elif integral_type.startswith("interior_facet"): + cell_orientations = gem.Variable(f"cell_orientations_{i}", (2,), dtype=gem.uint_type) + self._cell_orientations[domain] = (gem.Indexed(cell_orientations, (0,)), + gem.Indexed(cell_orientations, (1,))) + else: + cell_orientations = gem.Variable(f"cell_orientations_{i}", (1,), dtype=gem.uint_type) + self._cell_orientations[domain] = (gem.Indexed(cell_orientations, (0,)),) - def set_cell_sizes(self, domain): - """Setup a fake coefficient for "cell sizes". + def set_cell_sizes(self, domains): + """Setup a fake coefficient for "cell sizes" for each domain. - :arg domain: The domain of the integral. + Parameters + ---------- + domains : list or tuple + All domains in the form. This is required for scaling of derivative basis functions on physically mapped elements (Argyris, Bell, etc...). We need a @@ -121,13 +154,15 @@ def set_cell_sizes(self, domain): Should the domain have topological dimension 0 this does nothing. """ - if domain.ufl_cell().topological_dimension() > 0: - # Can't create P1 since only P0 is a valid finite element if - # topological_dimension is 0 and the concept of "cell size" - # is not useful for a vertex. - f = Coefficient(FunctionSpace(domain, FiniteElement("P", domain.ufl_cell(), 1))) - expr = prepare_coefficient(f, "cell_sizes", interior_facet=self.interior_facet) - self._cell_sizes = expr + self._cell_sizes = {} + for i, domain in enumerate(domains): + if domain.ufl_cell().topological_dimension() > 0: + # Can't create P1 since only P0 is a valid finite element if + # topological_dimension is 0 and the concept of "cell size" + # is not useful for a vertex. + f = Coefficient(FunctionSpace(domain, FiniteElement("P", domain.ufl_cell(), 1))) + expr = prepare_coefficient(f, f"cell_sizes_{i}", self._domain_integral_type_map) + self._cell_sizes[domain] = expr def create_element(self, element, **kwargs): """Create a FInAT element (suitable for tabulating with) given @@ -194,7 +229,7 @@ def set_coefficient_numbers(self, coefficient_numbers): def register_requirements(self, ir): """Inspect what is referenced by the IR that needs to be provided by the kernel interface.""" - self.oriented, self.cell_sizes, self.tabulations, _ = check_requirements(ir) + self.oriented, self.cell_sizes, self.tabulations = check_requirements(ir) def set_output(self, o): """Produce the kernel return argument""" @@ -214,10 +249,12 @@ def construct_kernel(self, impero_c, index_names, needs_external_coords, log=Fal """ args = [self.output_arg] if self.oriented: - funarg = self.generate_arg_from_expression(self._cell_orientations, dtype=numpy.int32) + cell_orientations, = tuple(self._cell_orientations.values()) + funarg = self.generate_arg_from_expression(cell_orientations, dtype=numpy.int32) args.append(kernel_args.CellOrientationsKernelArg(funarg)) if self.cell_sizes: - funarg = self.generate_arg_from_expression(self._cell_sizes) + cell_sizes, = tuple(self._cell_sizes.values()) + funarg = self.generate_arg_from_expression(cell_sizes) args.append(kernel_args.CellSizesKernelArg(funarg)) for _, expr in self.coefficient_map.items(): # coefficient_map is OrderedDict. @@ -249,48 +286,18 @@ class KernelBuilder(KernelBuilderBase, KernelBuilderMixin): def __init__(self, integral_data_info, scalar_type, diagonal=False): """Initialise a kernel builder.""" - integral_type = integral_data_info.integral_type - super(KernelBuilder, self).__init__(scalar_type, integral_type.startswith("interior_facet")) + super(KernelBuilder, self).__init__(scalar_type) self.fem_scalar_type = scalar_type - self.diagonal = diagonal self.local_tensor = None self.coefficient_number_index_map = OrderedDict() - - # Facet number - if integral_type in ['exterior_facet', 'exterior_facet_vert']: - facet = gem.Variable('facet', (1,), dtype=gem.uint_type) - self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))} - facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) - self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))} - elif integral_type in ['interior_facet', 'interior_facet_vert']: - facet = gem.Variable('facet', (2,), dtype=gem.uint_type) - self._entity_number = { - '+': gem.VariableIndex(gem.Indexed(facet, (0,))), - '-': gem.VariableIndex(gem.Indexed(facet, (1,))) - } - facet_orientation = gem.Variable('facet_orientation', (2,), dtype=gem.uint_type) - self._entity_orientation = { - '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), - '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (1,))) - } - elif integral_type == 'interior_facet_horiz': - self._entity_number = {'+': 1, '-': 0} - facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) # base mesh entity orientation - self._entity_orientation = { - '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), - '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))) - } - - self.set_arguments(integral_data_info.arguments) self.integral_data_info = integral_data_info + self._domain_integral_type_map = integral_data_info.domain_integral_type_map # For consistency with ExpressionKernelBuilder. + self.set_arguments() - def set_arguments(self, arguments): - """Process arguments. - - :arg arguments: :class:`ufl.Argument`s - :returns: GEM expression representing the return variable - """ + def set_arguments(self): + """Process arguments.""" + arguments = self.integral_data_info.arguments argument_multiindices = tuple(create_element(arg.ufl_element()).get_indices() for arg in arguments) if self.diagonal: @@ -301,11 +308,69 @@ def set_arguments(self, arguments): argument_multiindices = (a, a) return_variables = prepare_arguments(arguments, argument_multiindices, - interior_facet=self.interior_facet, + self.integral_data_info.domain_integral_type_map, diagonal=self.diagonal) self.return_variables = return_variables self.argument_multiindices = argument_multiindices + def set_entity_numbers(self, domains): + """Set entity numbers for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. + + """ + self._entity_numbers = {} + for i, domain in enumerate(domains): + # Facet number + integral_type = self.integral_data_info.domain_integral_type_map[domain] + if integral_type in ['exterior_facet', 'exterior_facet_vert']: + facet = gem.Variable(f'facet_{i}', (1,), dtype=gem.uint_type) + self._entity_numbers[domain] = {None: gem.VariableIndex(gem.Indexed(facet, (0,))), } + elif integral_type in ['interior_facet', 'interior_facet_vert']: + facet = gem.Variable(f'facet_{i}', (2,), dtype=gem.uint_type) + self._entity_numbers[domain] = { + '+': gem.VariableIndex(gem.Indexed(facet, (0,))), + '-': gem.VariableIndex(gem.Indexed(facet, (1,))) + } + elif integral_type == 'interior_facet_horiz': + self._entity_numbers[domain] = {'+': 1, '-': 0} + else: + self._entity_numbers[domain] = {None: None} + + def set_entity_orientations(self, domains): + """Set entity orientations for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. + + """ + self._entity_orientations = {} + for i, domain in enumerate(domains): + integral_type = self.integral_data_info.domain_integral_type_map[domain] + variable_name = f"entity_orientations_{i}" + if integral_type in ['exterior_facet', 'exterior_facet_vert']: + o = gem.Variable(variable_name, (1,), dtype=gem.uint_type) + self._entity_orientations[domain] = {None: gem.OrientationVariableIndex(gem.Indexed(o, (0,))), } + elif integral_type in ['interior_facet', 'interior_facet_vert']: + o = gem.Variable(variable_name, (2,), dtype=gem.uint_type) + self._entity_orientations[domain] = { + '+': gem.OrientationVariableIndex(gem.Indexed(o, (0,))), + '-': gem.OrientationVariableIndex(gem.Indexed(o, (1,))) + } + elif integral_type == 'interior_facet_horiz': + o = gem.Variable(variable_name, (1,), dtype=gem.uint_type) # base mesh entity orientation + self._entity_orientations[domain] = { + '+': gem.OrientationVariableIndex(gem.Indexed(o, (0,))), + '-': gem.OrientationVariableIndex(gem.Indexed(o, (0,))) + } + else: + self._entity_orientations[domain] = {None: None} + def set_coefficients(self): """Prepare the coefficients of the form.""" info = self.integral_data_info @@ -342,7 +407,7 @@ def construct_kernel(self, name, ctx, log=False): :arg log: bool if the Kernel should be profiled with Log events :returns: :class:`Kernel` object """ - impero_c, oriented, needs_cell_sizes, tabulations, active_variables, need_facet_orientation = self.compile_gem(ctx) + impero_c, _, _, tabulations, active_variables = self.compile_gem(ctx) if impero_c is None: return self.construct_empty_kernel(name) info = self.integral_data_info @@ -358,50 +423,80 @@ def construct_kernel(self, name, ctx, log=False): # Add return arg funarg = self.generate_arg_from_expression(self.return_variables) args = [kernel_args.OutputKernelArg(funarg)] - # Add coordinates arg - coord = self.domain_coordinate[info.domain] - expr = self.coefficient_map[coord] - funarg = self.generate_arg_from_expression(expr) - args.append(kernel_args.CoordinatesKernelArg(funarg)) - if oriented: - funarg = self.generate_arg_from_expression(self._cell_orientations, dtype=numpy.int32) - args.append(kernel_args.CellOrientationsKernelArg(funarg)) - if needs_cell_sizes: - funarg = self.generate_arg_from_expression(self._cell_sizes) - args.append(kernel_args.CellSizesKernelArg(funarg)) + active_domain_numbers_coordinates, args_ = self.make_active_domain_numbers({d: self.coefficient_map[c] for d, c in self.domain_coordinate.items()}, + active_variables, + kernel_args.CoordinatesKernelArg) + args.extend(args_) + active_domain_numbers_cell_orientations, args_ = self.make_active_domain_numbers(self._cell_orientations, + active_variables, + kernel_args.CellOrientationsKernelArg, + dtype=numpy.int32) + args.extend(args_) + active_domain_numbers_cell_sizes, args_ = self.make_active_domain_numbers(self._cell_sizes, + active_variables, + kernel_args.CellSizesKernelArg) + args.extend(args_) coefficient_indices = OrderedDict() for coeff, (number, index) in self.coefficient_number_index_map.items(): a = coefficient_indices.setdefault(number, []) expr = self.coefficient_map[coeff] + if expr is None: + # See comment in prepare_coefficient. + continue var, = gem.extract_type(expr if isinstance(expr, tuple) else (expr, ), gem.Variable) if var in active_variables: funarg = self.generate_arg_from_expression(expr) args.append(kernel_args.CoefficientKernelArg(funarg)) a.append(index) - - # now constants for gemexpr in self.constant_map.values(): funarg = self.generate_arg_from_expression(gemexpr) args.append(kernel_args.ConstantKernelArg(funarg)) - coefficient_indices = tuple(tuple(v) for v in coefficient_indices.values()) assert len(coefficient_indices) == len(info.coefficient_numbers) - if info.integral_type in ["exterior_facet", "exterior_facet_vert"]: - ext_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(1,)) - args.append(kernel_args.ExteriorFacetKernelArg(ext_loopy_arg)) - elif info.integral_type in ["interior_facet", "interior_facet_vert"]: - int_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(2,)) - args.append(kernel_args.InteriorFacetKernelArg(int_loopy_arg)) - # The submesh PR will introduce a robust mechanism to check if a Variable - # is actually used in the final form of the expression, so there will be - # no need to get "need_facet_orientation" from self.compile_gem(). - if need_facet_orientation: - if info.integral_type == "exterior_facet": - ext_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(1,)) - args.append(kernel_args.ExteriorFacetOrientationKernelArg(ext_ornt_loopy_arg)) - elif info.integral_type == "interior_facet": - int_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(2,)) - args.append(kernel_args.InteriorFacetOrientationKernelArg(int_ornt_loopy_arg)) + ext_dict = {} + for domain, expr in self._entity_numbers.items(): + integral_type = info.domain_integral_type_map[domain] + ext_dict[domain] = expr[None].expression if integral_type in ["exterior_facet", "exterior_facet_vert"] else None + active_domain_numbers_exterior_facets, args_ = self.make_active_domain_numbers( + ext_dict, + active_variables, + kernel_args.ExteriorFacetKernelArg, + dtype=numpy.uint32, + ) + args.extend(args_) + int_dict = {} + for domain, expr in self._entity_numbers.items(): + integral_type = info.domain_integral_type_map[domain] + int_dict[domain] = expr['+'].expression if integral_type in ["interior_facet", "interior_facet_vert"] else None + active_domain_numbers_interior_facets, args_ = self.make_active_domain_numbers( + int_dict, + active_variables, + kernel_args.InteriorFacetKernelArg, + dtype=numpy.uint32, + ) + args.extend(args_) + ext_dict = {} + for domain, expr in self._entity_orientations.items(): + integral_type = info.domain_integral_type_map[domain] + ext_dict[domain] = expr[None].expression if integral_type in ["exterior_facet", "exterior_facet_vert"] else None + active_domain_numbers_orientations_exterior_facet, args_ = self.make_active_domain_numbers( + ext_dict, + active_variables, + kernel_args.OrientationsExteriorFacetKernelArg, + dtype=gem.uint_type, + ) + args.extend(args_) + int_dict = {} + for domain, expr in self._entity_orientations.items(): + integral_type = info.domain_integral_type_map[domain] + int_dict[domain] = expr['+'].expression if integral_type in ["interior_facet", "interior_facet_vert", "interior_facet_horiz"] else None + active_domain_numbers_orientations_interior_facet, args_ = self.make_active_domain_numbers( + int_dict, + active_variables, + kernel_args.OrientationsInteriorFacetKernelArg, + dtype=gem.uint_type, + ) + args.extend(args_) for name_, shape in tabulations: tab_loopy_arg = lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape) args.append(kernel_args.TabulationKernelArg(tab_loopy_arg)) @@ -414,9 +509,16 @@ def construct_kernel(self, name, ctx, log=False): integral_type=info.integral_type, subdomain_id=info.subdomain_id, domain_number=info.domain_number, + active_domain_numbers=ActiveDomainNumbers( + coordinates=tuple(active_domain_numbers_coordinates), + cell_orientations=tuple(active_domain_numbers_cell_orientations), + cell_sizes=tuple(active_domain_numbers_cell_sizes), + exterior_facets=tuple(active_domain_numbers_exterior_facets), + interior_facets=tuple(active_domain_numbers_interior_facets), + orientations_exterior_facet=tuple(active_domain_numbers_orientations_exterior_facet), + orientations_interior_facet=tuple(active_domain_numbers_orientations_interior_facet), + ), coefficient_numbers=tuple(zip(info.coefficient_numbers, coefficient_indices)), - oriented=oriented, - needs_cell_sizes=needs_cell_sizes, tabulations=tabulations, flop_count=flop_count, name=name, @@ -429,3 +531,36 @@ def construct_empty_kernel(self, name): :returns: None """ return None + + def make_active_domain_numbers(self, domain_expr_dict, active_variables, kernel_arg_type, dtype=None): + """Make active domain numbers. + + Parameters + ---------- + domain_expr_dict : dict + Map from domains to expressions; must be ordered as extract_domains(form). + active_variables : tuple + Active variables in the DAG. + kernel_arg_type : KernelArg + Type of `KernelArg`. + dtype : numpy.dtype + dtype. + + Returns + ------- + tuple + Tuple of active domain numbers and corresponding kernel args. + + """ + active_dns = [] + args = [] + for i, expr in enumerate(domain_expr_dict.values()): + if expr is None: + var = None + else: + var, = gem.extract_type(expr if isinstance(expr, tuple) else (expr, ), gem.Variable) + if var in active_variables: + funarg = self.generate_arg_from_expression(expr, dtype=dtype) + args.append(kernel_arg_type(funarg)) + active_dns.append(i) + return tuple(active_dns), tuple(args) diff --git a/tsfc/ufl_utils.py b/tsfc/ufl_utils.py index 18173a9660..c26febd68e 100644 --- a/tsfc/ufl_utils.py +++ b/tsfc/ufl_utils.py @@ -40,6 +40,7 @@ def compute_form_data(form, do_apply_integral_scaling=True, do_apply_geometry_lowering=True, preserve_geometry_types=preserve_geometry_types, + do_apply_default_restrictions=True, do_apply_restrictions=True, do_estimate_degrees=True, coefficients_to_split=None, @@ -57,6 +58,7 @@ def compute_form_data(form, do_apply_integral_scaling=do_apply_integral_scaling, do_apply_geometry_lowering=do_apply_geometry_lowering, preserve_geometry_types=preserve_geometry_types, + do_apply_default_restrictions=do_apply_default_restrictions, do_apply_restrictions=do_apply_restrictions, do_estimate_degrees=do_estimate_degrees, do_replace_functions=True, @@ -166,6 +168,8 @@ def _modified_terminal(self, o): positive_restricted = _modified_terminal negative_restricted = _modified_terminal + single_value_restricted = _modified_terminal + to_be_restricted = _modified_terminal reference_grad = _modified_terminal reference_value = _modified_terminal @@ -197,8 +201,11 @@ def modified_terminal(self, o): mt = analyse_modified_terminal(o) t = mt.terminal r = mt.restriction - if isinstance(t, Argument) and r != self.restrictions[t.number()]: - return Zero(o.ufl_shape, o.ufl_free_indices, o.ufl_index_dimensions) + if isinstance(t, Argument) and r in ['+', '-']: + if r == self.restrictions[t.number()]: + return o + else: + return Zero(o.ufl_shape, o.ufl_free_indices, o.ufl_index_dimensions) else: return o