From 926925dc46c12de02fdf6978208ed8beecfd9f48 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 31 Mar 2026 17:19:46 +0100 Subject: [PATCH 1/7] Submesh: support tuple subdomain_id --- firedrake/mesh.py | 22 ++++++++++++++++--- tests/firedrake/submesh/test_submesh_facet.py | 15 +++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 6cac68ac58..ec570a6013 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4806,10 +4806,11 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig subdim : int | None Topological dimension of the submesh. Defaults to ``mesh.topological_dimension``. - subdomain_id : int | None + subdomain_id : int | Sequence | None Subdomain ID representing the submesh. - If `None` the submesh will cover the entire domain. - This is useful to obtain a codim-1 submesh over all facets or + If multiple subdomain IDs are provided, their union is taken. + If `None` the submesh will cover the entire domain, + this is useful to obtain a codim-1 submesh over all facets or a submesh over a different communicator. label_name : str | None Name of the label to search ``subdomain_id`` in. @@ -4876,8 +4877,23 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig label_name = dmcommon.CELL_SETS_LABEL elif subdim == dim - 1: label_name = dmcommon.FACE_SETS_LABEL + + if isinstance(subdomain_id, (tuple, list)): + # A list of subdomain ids requires us to build an internal DM label with the union + iset = PETSc.IS().createGeneral([], comm=comm or mesh.comm) + for sub in subdomain_id: + iset = iset.union(plex.getStratumIS(label_name, sub)) + label_name = "temp_union" + subdomain_id = 1 + plex.createLabel(label_name) + label = plex.getLabel(label_name) + label.setStratumIS(subdomain_id, iset) + subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm) + if label_name == "temp_union": + plex.removeLabel(label_name) + comm = comm or mesh.comm name = name or _generate_default_submesh_name(mesh.name) subplex.setName(_generate_default_mesh_topology_name(name)) diff --git a/tests/firedrake/submesh/test_submesh_facet.py b/tests/firedrake/submesh/test_submesh_facet.py index 83ac29c404..31ff89331c 100644 --- a/tests/firedrake/submesh/test_submesh_facet.py +++ b/tests/firedrake/submesh/test_submesh_facet.py @@ -134,3 +134,18 @@ def test_submesh_facet_all_facets(): rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) assert submesh2.cell_set.size == submesh1.cell_set.size + + +def test_submesh_facet_subdomain_id_tuple(): + mesh = UnitCubeMesh(2, 2, 2) + subdomain_id = (1, 3, 6) + submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id) + assert abs(assemble(1*dx(domain=submesh1)) - len(subdomain_id)) < 1E-12 + + V = FunctionSpace(mesh, "HDiv Trace", 0) + facet_function = Function(V) + DirichletBC(V, 1, subdomain_id).apply(facet_function) + facet_value = 999 + rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) + submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) + assert submesh2.cell_set.size == submesh1.cell_set.size From 75f5c95eea75c2166ec9005bd0a54b15f4cead66 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Apr 2026 10:56:54 +0100 Subject: [PATCH 2/7] Take the intersection for nested lists of subdomain_ids --- firedrake/mesh.py | 30 ++++-- tests/firedrake/submesh/test_submesh_facet.py | 15 --- .../submesh/test_submesh_interface.py | 93 +++++++++++++++++++ 3 files changed, 115 insertions(+), 23 deletions(-) create mode 100644 tests/firedrake/submesh/test_submesh_interface.py diff --git a/firedrake/mesh.py b/firedrake/mesh.py index ec570a6013..82a3b25b7e 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -25,7 +25,7 @@ from pyop2.mpi import ( MPI, COMM_WORLD, temp_internal_comm ) -from functools import cached_property +from functools import cached_property, reduce from pyop2.utils import as_tuple import petsctools from petsctools import OptionsManager, get_external_packages @@ -4809,6 +4809,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig subdomain_id : int | Sequence | None Subdomain ID representing the submesh. If multiple subdomain IDs are provided, their union is taken. + If nested lists of subdomain IDs are provided, their intersection is taken. If `None` the submesh will cover the entire domain, this is useful to obtain a codim-1 submesh over all facets or a submesh over a different communicator. @@ -4878,12 +4879,25 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig elif subdim == dim - 1: label_name = dmcommon.FACE_SETS_LABEL - if isinstance(subdomain_id, (tuple, list)): - # A list of subdomain ids requires us to build an internal DM label with the union - iset = PETSc.IS().createGeneral([], comm=comm or mesh.comm) + # Parse non-integer subdomain_id + if subdomain_id == "on_boundary": + subdomain_id = tuple(mesh.exterior_facets.unique_markers) + + if isinstance(subdomain_id, Sequence): + # Create a temporary DMLabel with the union of the labels in the list + icomm = comm or mesh.comm + iset = PETSc.IS().createGeneral([], comm=icomm) for sub in subdomain_id: - iset = iset.union(plex.getStratumIS(label_name, sub)) - label_name = "temp_union" + if isinstance(sub, Sequence): + # Take the intersection of the (closure of the) labels from nested lists + ises = [plex.getStratumIS(label_name, subi) for subi in sub] + closure = [[plex.getTransitiveClosure(p)[0] for p in i.indices] for i in ises] + indices = reduce(np.intersect1d, closure) + cur = PETSc.IS().createGeneral(indices, comm=icomm) + else: + cur = plex.getStratumIS(label_name, sub) + iset = iset.union(cur) + label_name = "temp_label" subdomain_id = 1 plex.createLabel(label_name) label = plex.getLabel(label_name) @@ -4891,7 +4905,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm) - if label_name == "temp_union": + if label_name == "temp_label": plex.removeLabel(label_name) comm = comm or mesh.comm @@ -4900,7 +4914,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig if subplex.getDimension() != subdim: raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})") if reorder is None: - # Ideally we should set perm_is = mesh.dm_reordering[label_indices] + # Ideally we should set perm_is = mesh._dm_renumbering[label_indices] reorder = mesh._did_reordering submesh = Mesh( diff --git a/tests/firedrake/submesh/test_submesh_facet.py b/tests/firedrake/submesh/test_submesh_facet.py index 31ff89331c..83ac29c404 100644 --- a/tests/firedrake/submesh/test_submesh_facet.py +++ b/tests/firedrake/submesh/test_submesh_facet.py @@ -134,18 +134,3 @@ def test_submesh_facet_all_facets(): rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) assert submesh2.cell_set.size == submesh1.cell_set.size - - -def test_submesh_facet_subdomain_id_tuple(): - mesh = UnitCubeMesh(2, 2, 2) - subdomain_id = (1, 3, 6) - submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id) - assert abs(assemble(1*dx(domain=submesh1)) - len(subdomain_id)) < 1E-12 - - V = FunctionSpace(mesh, "HDiv Trace", 0) - facet_function = Function(V) - DirichletBC(V, 1, subdomain_id).apply(facet_function) - facet_value = 999 - rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) - submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) - assert submesh2.cell_set.size == submesh1.cell_set.size diff --git a/tests/firedrake/submesh/test_submesh_interface.py b/tests/firedrake/submesh/test_submesh_interface.py new file mode 100644 index 0000000000..c96cf5f574 --- /dev/null +++ b/tests/firedrake/submesh/test_submesh_interface.py @@ -0,0 +1,93 @@ +import pytest +import numpy as np +from firedrake import * + + +def test_submesh_subdomain_id_tuple(): + mesh = UnitSquareMesh(4, 4) + x, y = SpatialCoordinate(mesh) + M = FunctionSpace(mesh, "DG", 0) + m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0)) + m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0)) + mesh.mark_entities(m1, 111) + mesh.mark_entities(m2, 222) + + subdomain_id = [111, 222] + submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id) + + m3 = Function(M).interpolate(m1 + m2 - m1 * m2) + expected = assemble(m3*dx) + assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12 + + mesh.mark_entities(m3, 333) + submesh2 = Submesh(mesh, mesh.topological_dimension, 333) + assert submesh2.cell_set.size == submesh1.cell_set.size + assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) + + +def test_submesh_subdomain_id_nested_tuple(): + mesh = UnitSquareMesh(4, 4) + x, y = SpatialCoordinate(mesh) + M = FunctionSpace(mesh, "DG", 0) + m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0)) + m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0)) + mesh.mark_entities(m1, 111) + mesh.mark_entities(m2, 222) + + subdomain_id = [(111, 222)] + submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id) + + m3 = Function(M).interpolate(m1 * m2) + expected = assemble(m3*dx) + assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12 + + mesh.mark_entities(m3, 333) + submesh2 = Submesh(mesh, mesh.topological_dimension, 333) + assert submesh2.cell_set.size == submesh1.cell_set.size + assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) + + +@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)]) +def test_submesh_facet_subdomain_id_tuple(subdomain_id): + mesh = UnitCubeMesh(2, 2, 2) + submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id) + if subdomain_id == "on_boundary": + area = assemble(1*ds(domain=mesh)) + else: + area = assemble(1*ds(subdomain_id, domain=mesh)) + assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12 + + V = FunctionSpace(mesh, "HDiv Trace", 0) + facet_function = Function(V) + DirichletBC(V, 1, subdomain_id).apply(facet_function) + facet_value = 999 + rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) + submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) + assert submesh2.cell_set.size == submesh1.cell_set.size + assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) + + +def test_submesh_facet_subdomain_id_nested_tuple(): + mesh = UnitSquareMesh(4, 4) + x, y = SpatialCoordinate(mesh) + M = FunctionSpace(mesh, "DG", 0) + m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0)) + m2 = Function(M).interpolate(conditional(lt(x, 0.5), 0, 1)) + mesh.mark_entities(m1, 111) + mesh.mark_entities(m2, 222) + + subdomain_id = [(111, 222)] + submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id, label_name="Cell Sets") + + expected = 1 + assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12 + + x, y = SpatialCoordinate(mesh) + V = FunctionSpace(mesh, "HDiv Trace", 0) + facet_function = Function(V) + facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0)) + facet_value = 999 + rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) + submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) + assert submesh2.cell_set.size == submesh1.cell_set.size + assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) From 9d3903538d5151fc1917cc00ec33718a56fede30 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Apr 2026 18:07:37 +0100 Subject: [PATCH 3/7] review comments --- firedrake/cython/dmcommon.pyx | 42 ++++++++++++++++++- firedrake/cython/petschdr.pxi | 1 + firedrake/mesh.py | 18 +++++--- .../submesh/test_submesh_interface.py | 8 ++-- 4 files changed, 57 insertions(+), 12 deletions(-) diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index a4d4d92460..3b21d8df8b 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -2203,11 +2203,11 @@ def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray): def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]: """Return mesh periodicity information. - + This function returns a 2-tuple of bools per dimension where the first entry indicates whether the mesh is periodic in that dimension, and the second indicates whether the mesh is single-cell periodic in that dimension. - + """ cdef: const PetscReal *maxCell, *L @@ -4325,3 +4325,41 @@ def get_dm_cell_types(PETSc.DM dm): return tuple( polytope_type_enum for polytope_type_enum, found in enumerate(found_all) if found ) + + +def create_label_intersection(PETSc.DM dm, label_name, label_values): + """Return the intersection of the closure of a subdomains of a DMPlex. + + Parameters + ---------- + dm : PETSc.DM + The DMPlex. + label_name : str + The name of the label + label_values : Sequence[int] + The values of the subdomain label to intersect + + Returns + ------- + tuple + A PETSc.IS with the points in the intersection. + + """ + cdef: + PETSc.DMLabel label + PETSc.PetscIS is1, is2 + PetscInt val = label_values[0] + + label = dm.getLabel(label_name) + CHKERR(DMPlexLabelComplete(dm.dm, label.dmlabel)) + CHKERR(DMLabelGetStratumIS(label.dmlabel, val, &is1)) + + for i in range(1, len(label_values)): + iout = PETSc.IS() + val = label_values[i] + CHKERR(DMLabelGetStratumIS(label.dmlabel, val, &is2)) + CHKERR(ISIntersect(is1, is2, &(iout).iset)) + CHKERR(ISDestroy(&is1)) + CHKERR(ISDestroy(&is2)) + is1 = (iout).iset + return iout diff --git a/firedrake/cython/petschdr.pxi b/firedrake/cython/petschdr.pxi index 42ac97e24d..445f5dace7 100644 --- a/firedrake/cython/petschdr.pxi +++ b/firedrake/cython/petschdr.pxi @@ -142,6 +142,7 @@ cdef extern from "petscis.h" nogil: PetscErrorCode ISLocalToGlobalMappingGetBlockIndices(PETSc.PetscLGMap, const PetscInt**) PetscErrorCode ISLocalToGlobalMappingRestoreBlockIndices(PETSc.PetscLGMap, const PetscInt**) PetscErrorCode ISDestroy(PETSc.PetscIS*) + PetscErrorCode ISIntersect(PETSc.PetscIS, PETSc.PetscIS, PETSc.PetscIS*) cdef extern from "petscsf.h" nogil: struct PetscSFNode_: diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 82a3b25b7e..ceecde0788 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4863,6 +4863,8 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``") if subdim is None: subdim = mesh.topological_dimension + if subdomain_id == "on_boundary": + subdim = subdim - 1 plex = mesh.topology_dm dim = plex.getDimension() if subdim not in {dim, dim - 1}: @@ -4880,20 +4882,24 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig label_name = dmcommon.FACE_SETS_LABEL # Parse non-integer subdomain_id - if subdomain_id == "on_boundary": - subdomain_id = tuple(mesh.exterior_facets.unique_markers) + if isinstance(subdomain_id, str): + if subdomain_id == "on_boundary": + subdomain_id = tuple(mesh.exterior_facets.unique_markers) + else: + raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") if isinstance(subdomain_id, Sequence): # Create a temporary DMLabel with the union of the labels in the list icomm = comm or mesh.comm iset = PETSc.IS().createGeneral([], comm=icomm) for sub in subdomain_id: + try: + sub, = sub + except ValueError: + pass if isinstance(sub, Sequence): # Take the intersection of the (closure of the) labels from nested lists - ises = [plex.getStratumIS(label_name, subi) for subi in sub] - closure = [[plex.getTransitiveClosure(p)[0] for p in i.indices] for i in ises] - indices = reduce(np.intersect1d, closure) - cur = PETSc.IS().createGeneral(indices, comm=icomm) + cur = dmcommon.create_label_intersection(plex, label_name, sub) else: cur = plex.getStratumIS(label_name, sub) iset = iset.union(cur) diff --git a/tests/firedrake/submesh/test_submesh_interface.py b/tests/firedrake/submesh/test_submesh_interface.py index c96cf5f574..9edde80aed 100644 --- a/tests/firedrake/submesh/test_submesh_interface.py +++ b/tests/firedrake/submesh/test_submesh_interface.py @@ -3,7 +3,7 @@ from firedrake import * -def test_submesh_subdomain_id_tuple(): +def test_submesh_subdomain_id_union(): mesh = UnitSquareMesh(4, 4) x, y = SpatialCoordinate(mesh) M = FunctionSpace(mesh, "DG", 0) @@ -25,7 +25,7 @@ def test_submesh_subdomain_id_tuple(): assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) -def test_submesh_subdomain_id_nested_tuple(): +def test_submesh_subdomain_id_intersection(): mesh = UnitSquareMesh(4, 4) x, y = SpatialCoordinate(mesh) M = FunctionSpace(mesh, "DG", 0) @@ -48,7 +48,7 @@ def test_submesh_subdomain_id_nested_tuple(): @pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)]) -def test_submesh_facet_subdomain_id_tuple(subdomain_id): +def test_submesh_facet_subdomain_id_union(subdomain_id): mesh = UnitCubeMesh(2, 2, 2) submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id) if subdomain_id == "on_boundary": @@ -67,7 +67,7 @@ def test_submesh_facet_subdomain_id_tuple(subdomain_id): assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) -def test_submesh_facet_subdomain_id_nested_tuple(): +def test_submesh_facet_subdomain_id_intersection(): mesh = UnitSquareMesh(4, 4) x, y = SpatialCoordinate(mesh) M = FunctionSpace(mesh, "DG", 0) From 3c5f20d2f2e3382407fca76632406fe9b2731c81 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Apr 2026 23:46:21 +0100 Subject: [PATCH 4/7] fixes --- firedrake/cython/dmcommon.pyx | 24 +++++++++++------------- firedrake/mesh.py | 9 ++------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index 3b21d8df8b..240810240d 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -4328,7 +4328,7 @@ def get_dm_cell_types(PETSc.DM dm): def create_label_intersection(PETSc.DM dm, label_name, label_values): - """Return the intersection of the closure of a subdomains of a DMPlex. + """Return the intersection of the closure of subdomains of a DMPlex. Parameters ---------- @@ -4341,25 +4341,23 @@ def create_label_intersection(PETSc.DM dm, label_name, label_values): Returns ------- - tuple + PETSc.IS A PETSc.IS with the points in the intersection. """ cdef: + PETSc.IS iout, i1, i2 PETSc.DMLabel label - PETSc.PetscIS is1, is2 - PetscInt val = label_values[0] + + if len(label_values) == 0: + return PETSc.IS().createGeneral([], comm=dm.comm) label = dm.getLabel(label_name) CHKERR(DMPlexLabelComplete(dm.dm, label.dmlabel)) - CHKERR(DMLabelGetStratumIS(label.dmlabel, val, &is1)) - - for i in range(1, len(label_values)): + iout = label.getStratumIS(label_values[0]) + for val in label_values[1:]: + i1 = iout + i2 = label.getStratumIS(val) iout = PETSc.IS() - val = label_values[i] - CHKERR(DMLabelGetStratumIS(label.dmlabel, val, &is2)) - CHKERR(ISIntersect(is1, is2, &(iout).iset)) - CHKERR(ISDestroy(&is1)) - CHKERR(ISDestroy(&is2)) - is1 = (iout).iset + CHKERR(ISIntersect(i1.iset, i2.iset, &iout.iset)) return iout diff --git a/firedrake/mesh.py b/firedrake/mesh.py index ceecde0788..5e92cb9c77 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4890,15 +4890,10 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig if isinstance(subdomain_id, Sequence): # Create a temporary DMLabel with the union of the labels in the list - icomm = comm or mesh.comm - iset = PETSc.IS().createGeneral([], comm=icomm) + iset = PETSc.IS().createGeneral([], comm=mesh.comm) for sub in subdomain_id: - try: - sub, = sub - except ValueError: - pass if isinstance(sub, Sequence): - # Take the intersection of the (closure of the) labels from nested lists + # Take the intersection of the labels from nested lists cur = dmcommon.create_label_intersection(plex, label_name, sub) else: cur = plex.getStratumIS(label_name, sub) From 95b12192b942cbad4b4e67bd7302595460764742 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 6 Apr 2026 20:13:17 +0100 Subject: [PATCH 5/7] add a bunch of examples --- firedrake/mesh.py | 86 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 76 insertions(+), 10 deletions(-) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 5e92cb9c77..777c628009 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -25,7 +25,7 @@ from pyop2.mpi import ( MPI, COMM_WORLD, temp_internal_comm ) -from functools import cached_property, reduce +from functools import cached_property from pyop2.utils import as_tuple import petsctools from petsctools import OptionsManager, get_external_packages @@ -4854,6 +4854,65 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig ridges to be contained in the quad mesh are shared by at most two facets to make the quad mesh orientation algorithm work. + Examples + -------- + + Mark a cell subdomain and construct a codim-0 submesh from all cells in the subdomain + + >>> mesh = UnitSquareMesh(4, 4) + >>> x, y = SpatialCoordinate(mesh) + >>> DG = FunctionSpace(mesh, "DG", 0) + >>> cell_marker = assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)) + >>> mesh.mark_entities(cell_marker, 111) + >>> submesh = Submesh(mesh, subdomain_id=111) + + Mark a facet subdomain and construct a codim-1 submesh from all facets in the subdomain + + >>> mesh = UnitSquareMesh(4, 4) + >>> x, y = SpatialCoordinate(mesh) + >>> DGT = FunctionSpace(mesh, "DGT", 0) + >>> facet_marker = assemble(interpolate(conditional(lt(abs(x-0.5), 1E-12), 1, 0), DGT)) + >>> mesh.mark_entities(facet_marker, 222) + >>> submesh = Submesh(mesh, dim=mesh.topological_dimension-1, subdomain_id=222) + + Construct a codim-0 submesh of the union of multiple subdomains by passing a list + + >>> mesh = UnitSquareMesh(4, 4) + >>> x, y = SpatialCoordinate(mesh) + >>> DG = FunctionSpace(mesh, "DG", 0) + >>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1) + >>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2) + >>> submesh = Submesh(mesh, subdomain_id=[1, 2]) + + Construct a codim-0 submesh of the intersection of multiple subdomains by passing a nested list + + >>> mesh = UnitSquareMesh(4, 4) + >>> x, y = SpatialCoordinate(mesh) + >>> DG = FunctionSpace(mesh, "DG", 0) + >>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1) + >>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2) + >>> submesh = Submesh(mesh, subdomain_id=[(1, 2)]) + + Construct a codim-1 submesh of all the facets (the skeleton mesh) + + >>> mesh = UnitSquareMesh(4, 4) + >>> submesh = Submesh(mesh, subdim=1) + + Construct a codim-1 submesh of the entire boundary + + >>> mesh = UnitSquareMesh(4, 4) + >>> submesh = Submesh(mesh, subdomain_id="on_boundary") + + Construct a codim-1 submesh of the union of multiple boundaries + + >>> mesh = UnitSquareMesh(4, 4) + >>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=[1, 2, 3]) + + Construct a codim-0 submesh of the part of the mesh owned by each MPI rank + + >>> mesh = UnitSquareMesh(4, 4) + >>> submesh = Submesh(mesh, comm=COMM_SELF) + """ if not isinstance(mesh, MeshGeometry): raise TypeError("Parent mesh must be a `MeshGeometry`") @@ -4861,10 +4920,20 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig raise NotImplementedError("Can not create a submesh of an ``ExtrudedMesh``") elif isinstance(mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``") + + if subdomain_id == "on_boundary": + if subdim is None: + subdim = mesh.topological_dimension - 1 + elif subdim != mesh.topological_dimension - 1: + raise ValueError('subdomain_id="on_boundary" requires subdim=dim-1') + if label_name is None: + label_name = "exterior_facets" + elif label_name != "exterior_facets": + raise ValueError('subdomain_id="on_boundary" requires label_name="exterior_facets"') + subdomain_id = 1 + if subdim is None: subdim = mesh.topological_dimension - if subdomain_id == "on_boundary": - subdim = subdim - 1 plex = mesh.topology_dm dim = plex.getDimension() if subdim not in {dim, dim - 1}: @@ -4883,13 +4952,9 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig # Parse non-integer subdomain_id if isinstance(subdomain_id, str): - if subdomain_id == "on_boundary": - subdomain_id = tuple(mesh.exterior_facets.unique_markers) - else: - raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") - - if isinstance(subdomain_id, Sequence): - # Create a temporary DMLabel with the union of the labels in the list + raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") + elif isinstance(subdomain_id, Sequence): + # Take the union of the labels in the list iset = PETSc.IS().createGeneral([], comm=mesh.comm) for sub in subdomain_id: if isinstance(sub, Sequence): @@ -4898,6 +4963,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig else: cur = plex.getStratumIS(label_name, sub) iset = iset.union(cur) + # Create a temporary label label_name = "temp_label" subdomain_id = 1 plex.createLabel(label_name) From df1eb17a8256b8c67a6c7fc3766161de66f094c2 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 10 Apr 2026 11:02:54 +0100 Subject: [PATCH 6/7] cleanup --- firedrake/cython/dmcommon.pyx | 33 ++++++++-------------------- firedrake/mesh.py | 41 +++++++++++++++-------------------- 2 files changed, 26 insertions(+), 48 deletions(-) diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index 240810240d..5b4c7280fb 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -4327,37 +4327,22 @@ def get_dm_cell_types(PETSc.DM dm): ) -def create_label_intersection(PETSc.DM dm, label_name, label_values): - """Return the intersection of the closure of subdomains of a DMPlex. +def intersectIS(PETSc.IS i1, PETSc.IS i2): + """Return the intersection of two IS objects. Parameters ---------- - dm : PETSc.DM - The DMPlex. - label_name : str - The name of the label - label_values : Sequence[int] - The values of the subdomain label to intersect + i1 : PETSc.IS + The first IS. + i2 : PETSc.IS + The second IS. Returns ------- PETSc.IS - A PETSc.IS with the points in the intersection. + A PETSc.IS with the intersection. """ - cdef: - PETSc.IS iout, i1, i2 - PETSc.DMLabel label - - if len(label_values) == 0: - return PETSc.IS().createGeneral([], comm=dm.comm) - - label = dm.getLabel(label_name) - CHKERR(DMPlexLabelComplete(dm.dm, label.dmlabel)) - iout = label.getStratumIS(label_values[0]) - for val in label_values[1:]: - i1 = iout - i2 = label.getStratumIS(val) - iout = PETSc.IS() - CHKERR(ISIntersect(i1.iset, i2.iset, &iout.iset)) + cdef PETSc.IS iout = PETSc.IS() + CHKERR(ISIntersect(i1.iset, i2.iset, &iout.iset)) return iout diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 777c628009..992f5d2b52 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -25,7 +25,7 @@ from pyop2.mpi import ( MPI, COMM_WORLD, temp_internal_comm ) -from functools import cached_property +from functools import cached_property, reduce from pyop2.utils import as_tuple import petsctools from petsctools import OptionsManager, get_external_packages @@ -4856,62 +4856,48 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig Examples -------- - - Mark a cell subdomain and construct a codim-0 submesh from all cells in the subdomain - >>> mesh = UnitSquareMesh(4, 4) >>> x, y = SpatialCoordinate(mesh) >>> DG = FunctionSpace(mesh, "DG", 0) + >>> DGT = FunctionSpace(mesh, "DGT", 0) + + Mark a cell subdomain and construct a codim-0 submesh from all cells in the subdomain + >>> cell_marker = assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)) >>> mesh.mark_entities(cell_marker, 111) >>> submesh = Submesh(mesh, subdomain_id=111) Mark a facet subdomain and construct a codim-1 submesh from all facets in the subdomain - >>> mesh = UnitSquareMesh(4, 4) - >>> x, y = SpatialCoordinate(mesh) - >>> DGT = FunctionSpace(mesh, "DGT", 0) >>> facet_marker = assemble(interpolate(conditional(lt(abs(x-0.5), 1E-12), 1, 0), DGT)) >>> mesh.mark_entities(facet_marker, 222) - >>> submesh = Submesh(mesh, dim=mesh.topological_dimension-1, subdomain_id=222) + >>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=222) Construct a codim-0 submesh of the union of multiple subdomains by passing a list - >>> mesh = UnitSquareMesh(4, 4) - >>> x, y = SpatialCoordinate(mesh) - >>> DG = FunctionSpace(mesh, "DG", 0) >>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1) >>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2) >>> submesh = Submesh(mesh, subdomain_id=[1, 2]) Construct a codim-0 submesh of the intersection of multiple subdomains by passing a nested list - >>> mesh = UnitSquareMesh(4, 4) - >>> x, y = SpatialCoordinate(mesh) - >>> DG = FunctionSpace(mesh, "DG", 0) - >>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1) - >>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2) >>> submesh = Submesh(mesh, subdomain_id=[(1, 2)]) Construct a codim-1 submesh of all the facets (the skeleton mesh) - >>> mesh = UnitSquareMesh(4, 4) >>> submesh = Submesh(mesh, subdim=1) Construct a codim-1 submesh of the entire boundary - >>> mesh = UnitSquareMesh(4, 4) >>> submesh = Submesh(mesh, subdomain_id="on_boundary") Construct a codim-1 submesh of the union of multiple boundaries - >>> mesh = UnitSquareMesh(4, 4) >>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=[1, 2, 3]) Construct a codim-0 submesh of the part of the mesh owned by each MPI rank - >>> mesh = UnitSquareMesh(4, 4) - >>> submesh = Submesh(mesh, comm=COMM_SELF) + >>> submesh = Submesh(mesh, ignore_halo=True, comm=COMM_SELF) """ if not isinstance(mesh, MeshGeometry): @@ -4954,14 +4940,21 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig if isinstance(subdomain_id, str): raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") elif isinstance(subdomain_id, Sequence): + label = plex.getLabel(label_name) + if subdim != dim: + plex.labelComplete(label) # Take the union of the labels in the list iset = PETSc.IS().createGeneral([], comm=mesh.comm) for sub in subdomain_id: - if isinstance(sub, Sequence): + if isinstance(sub, str): + raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") + elif isinstance(sub, Sequence): # Take the intersection of the labels from nested lists - cur = dmcommon.create_label_intersection(plex, label_name, sub) + if len(sub) == 0: + continue + cur = reduce(dmcommon.intersectIS, map(label.getStratumIS, sub)) else: - cur = plex.getStratumIS(label_name, sub) + cur = label.getStratumIS(sub) iset = iset.union(cur) # Create a temporary label label_name = "temp_label" From 55a87300af01c40a838406e53c3e45808c9f77dc Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 10 Apr 2026 12:45:23 +0100 Subject: [PATCH 7/7] allow union/intersection with on_boundary --- firedrake/mesh.py | 17 +++++--- .../submesh/test_submesh_interface.py | 40 +++++++++++++------ 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 992f5d2b52..b48b33566f 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4943,18 +4943,25 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig label = plex.getLabel(label_name) if subdim != dim: plex.labelComplete(label) + + def get_points(sub): + if sub == "on_boundary": + return plex.getStratumIS("exterior_facets", 1) + elif isinstance(sub, str): + raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") + else: + return label.getStratumIS(sub) + # Take the union of the labels in the list iset = PETSc.IS().createGeneral([], comm=mesh.comm) for sub in subdomain_id: - if isinstance(sub, str): - raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") - elif isinstance(sub, Sequence): + if isinstance(sub, Sequence) and not isinstance(sub, str): # Take the intersection of the labels from nested lists if len(sub) == 0: continue - cur = reduce(dmcommon.intersectIS, map(label.getStratumIS, sub)) + cur = reduce(dmcommon.intersectIS, map(get_points, sub)) else: - cur = label.getStratumIS(sub) + cur = get_points(sub) iset = iset.union(cur) # Create a temporary label label_name = "temp_label" diff --git a/tests/firedrake/submesh/test_submesh_interface.py b/tests/firedrake/submesh/test_submesh_interface.py index 9edde80aed..c8fc92071b 100644 --- a/tests/firedrake/submesh/test_submesh_interface.py +++ b/tests/firedrake/submesh/test_submesh_interface.py @@ -57,9 +57,9 @@ def test_submesh_facet_subdomain_id_union(subdomain_id): area = assemble(1*ds(subdomain_id, domain=mesh)) assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12 - V = FunctionSpace(mesh, "HDiv Trace", 0) - facet_function = Function(V) - DirichletBC(V, 1, subdomain_id).apply(facet_function) + DGT = FunctionSpace(mesh, "DGT", 0) + facet_function = Function(DGT) + DirichletBC(DGT, 1, subdomain_id).apply(facet_function) facet_value = 999 rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) @@ -67,25 +67,39 @@ def test_submesh_facet_subdomain_id_union(subdomain_id): assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) -def test_submesh_facet_subdomain_id_intersection(): +@pytest.mark.parametrize("sub", ["cell-cell", "cell-boundary"]) +def test_submesh_facet_subdomain_id_intersection(sub): + if sub == "cell-cell": + # (x <= 0.5) & (x >= 0.5) + subdomain_id = [(111, 222)] + expected = 1 + elif sub == "cell-boundary": + # (x <= 0.5) & (x == 0 | y == 0 | y == 1) + subdomain_id = [(111, "on_boundary")] + expected = 2 + mesh = UnitSquareMesh(4, 4) x, y = SpatialCoordinate(mesh) - M = FunctionSpace(mesh, "DG", 0) - m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0)) - m2 = Function(M).interpolate(conditional(lt(x, 0.5), 0, 1)) + DG = FunctionSpace(mesh, "DG", 0) + DGT = FunctionSpace(mesh, "DGT", 0) + m1 = Function(DG).interpolate(conditional(lt(x, 0.5), 1, 0)) + m2 = Function(DG).interpolate(conditional(lt(x, 0.5), 0, 1)) mesh.mark_entities(m1, 111) mesh.mark_entities(m2, 222) - subdomain_id = [(111, 222)] submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id, label_name="Cell Sets") - expected = 1 assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12 - x, y = SpatialCoordinate(mesh) - V = FunctionSpace(mesh, "HDiv Trace", 0) - facet_function = Function(V) - facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0)) + facet_function = Function(DGT) + if sub == "cell-cell": + facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0)) + elif sub == "cell-boundary": + facet_function.interpolate(conditional(lt(x, 0.5), 1, 0)) + bnd = Function(DGT) + DirichletBC(DGT, 1, "on_boundary").apply(bnd) + facet_function.dat.data[:] *= bnd.dat.data_ro[:] + facet_value = 999 rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)