diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index d3eeb2f46b..ce739edf4d 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -176,28 +176,6 @@ def is_dg_space(space: WithGeometry) -> bool: return e.is_dg() -def dg_space(space: WithGeometry) -> WithGeometry: - """Construct a DG space containing a given function space as a subspace. - - Parameters - ---------- - - space - A function space. - - Returns - ------- - - firedrake.functionspaceimpl.WithGeometry - A DG space containing `space` as a subspace. May be `space`. - """ - - if is_dg_space(space): - return space - else: - return fd.FunctionSpace(space.mesh(), finat.ufl.BrokenElement(space.ufl_element())) - - class L2TransformedFunctional(AbstractReducedFunctional): r"""Represents the functional @@ -265,7 +243,8 @@ def __init__(self, functional: pyadjoint.OverloadedType, controls: Union[Control self._space_D = Enlist(space_D) if len(self._space_D) != len(self._space): raise ValueError("Invalid length") - self._space_D = tuple(dg_space(space) if space_D is None else space_D + self._space_D = tuple((space if is_dg_space(space) else space.broken_space()) + if space_D is None else space_D for space, space_D in zip(self._space, self._space_D)) self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2") diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 5b8d6b44a1..926824345f 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -403,6 +403,19 @@ def make_function_space(cls, mesh, element, name=None): new = cls.create(new, mesh) return new + def broken_space(self): + """Return a :class:`.WithGeometryBase` with a :class:`finat.ufl.BrokenElement` + constructed from this function space's FiniteElement. + + Returns + ------- + WithGeometryBase : + The new function space with a :class:`~finat.ufl.BrokenElement`. + """ + return type(self).make_function_space( + self.mesh(), finat.ufl.BrokenElement(self.ufl_element()), + name=f"{self.name}_broken" if self.name else None) + def reconstruct( self, mesh: MeshGeometry | None = None, diff --git a/firedrake/slate/static_condensation/hybridization.py b/firedrake/slate/static_condensation/hybridization.py index cfc72f2119..e66c4c547f 100644 --- a/firedrake/slate/static_condensation/hybridization.py +++ b/firedrake/slate/static_condensation/hybridization.py @@ -1,7 +1,6 @@ import functools import ufl -import finat.ufl import firedrake.dmhooks as dmhooks from firedrake.slate.static_condensation.sc_base import SCBase @@ -90,8 +89,7 @@ def initialize(self, pc): TraceSpace = FunctionSpace(mesh[self.vidx], "HDiv Trace", tdegree) # Break the function spaces and define fully discontinuous spaces - broken_elements = finat.ufl.MixedElement([finat.ufl.BrokenElement(Vi.ufl_element()) for Vi in V]) - V_d = FunctionSpace(mesh, broken_elements) + V_d = V.broken_space() # Set up the functions for the original, hybridized # and schur complement systems diff --git a/tests/firedrake/regression/test_function_spaces.py b/tests/firedrake/regression/test_function_spaces.py index ab22b66ddb..491354df19 100644 --- a/tests/firedrake/regression/test_function_spaces.py +++ b/tests/firedrake/regression/test_function_spaces.py @@ -120,7 +120,7 @@ def test_function_space_variant(mesh, space): @pytest.mark.parametrize("modifier", - [BrokenElement, HDivElement, + [HDivElement, HCurlElement]) @pytest.mark.parametrize("element", [FiniteElement("CG", triangle, 1), @@ -313,3 +313,46 @@ def test_reconstruct_sub_component(dg0, rt1, mesh, mesh2, dual): assert is_primal(V1.parent) == is_primal(V2.parent) != dual assert V1.parent.ufl_element() == V2.parent.ufl_element() assert V1.parent.index == V2.parent.index == index + + +@pytest.mark.parametrize("family", ("CG", "BDM", "DG")) +@pytest.mark.parametrize("shape", (0, 2, (2, 3)), ids=("0", "2", "(2,3)")) +def test_broken_space(mesh, shape, family): + """Check that FunctionSpace.broken_space returns the a + FunctionSpace with the correct element. + """ + kwargs = {"variant": "spectral"} if family == "DG" else {} + + elem = FiniteElement(family, mesh.ufl_cell(), 1, **kwargs) + + if not isinstance(shape, int): + make_element = lambda elem: TensorElement(elem, shape=shape) + elif shape > 0: + make_element = lambda elem: VectorElement(elem, dim=shape) + else: + make_element = lambda elem: elem + + fs = FunctionSpace(mesh, make_element(elem)) + broken = fs.broken_space() + expected = FunctionSpace(mesh, make_element(BrokenElement(elem))) + + assert broken == expected + + +def test_mixed_broken_space(mesh): + """Check that MixedFunctionSpace.broken_space returns the a + MixedFunctionSpace with the correct element. + """ + + mixed_elem = MixedElement([ + FiniteElement("CG", mesh.ufl_cell(), 1), + VectorElement("BDM", mesh.ufl_cell(), 2, dim=2), + TensorElement("DG", mesh.ufl_cell(), 1, shape=(2, 3), variant="spectral") + ]) + broken_elem = MixedElement([BrokenElement(elem) for elem in mixed_elem.sub_elements]) + + mfs = FunctionSpace(mesh, mixed_elem) + broken = mfs.broken_space() + expected = FunctionSpace(mesh, broken_elem) + + assert broken == expected