Skip to content
25 changes: 2 additions & 23 deletions firedrake/adjoint/transformed_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,28 +176,6 @@ def is_dg_space(space: WithGeometry) -> bool:
return e.is_dg()


def dg_space(space: WithGeometry) -> WithGeometry:
"""Construct a DG space containing a given function space as a subspace.

Parameters
----------

space
A function space.

Returns
-------

firedrake.functionspaceimpl.WithGeometry
A DG space containing `space` as a subspace. May be `space`.
"""

if is_dg_space(space):
return space
else:
return fd.FunctionSpace(space.mesh(), finat.ufl.BrokenElement(space.ufl_element()))


class L2TransformedFunctional(AbstractReducedFunctional):
r"""Represents the functional

Expand Down Expand Up @@ -265,7 +243,8 @@ def __init__(self, functional: pyadjoint.OverloadedType, controls: Union[Control
self._space_D = Enlist(space_D)
if len(self._space_D) != len(self._space):
raise ValueError("Invalid length")
self._space_D = tuple(dg_space(space) if space_D is None else space_D
self._space_D = tuple((space if is_dg_space(space) else space.broken_space())
if space_D is None else space_D
for space, space_D in zip(self._space, self._space_D))

self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2")
Expand Down
13 changes: 13 additions & 0 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,19 @@ def make_function_space(cls, mesh, element, name=None):
new = cls.create(new, mesh)
return new

def broken_space(self):
"""Return a :class:`.WithGeometryBase` with a :class:`finat.ufl.BrokenElement`
constructed from this function space's FiniteElement.

Returns
-------
WithGeometryBase :
The new function space with a :class:`~finat.ufl.BrokenElement`.
"""
return type(self).make_function_space(
self.mesh(), finat.ufl.BrokenElement(self.ufl_element()),
name=f"{self.name}_broken" if self.name else None)

def reconstruct(
self,
mesh: MeshGeometry | None = None,
Expand Down
4 changes: 1 addition & 3 deletions firedrake/slate/static_condensation/hybridization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools

import ufl
import finat.ufl

import firedrake.dmhooks as dmhooks
from firedrake.slate.static_condensation.sc_base import SCBase
Expand Down Expand Up @@ -90,8 +89,7 @@ def initialize(self, pc):
TraceSpace = FunctionSpace(mesh[self.vidx], "HDiv Trace", tdegree)

# Break the function spaces and define fully discontinuous spaces
broken_elements = finat.ufl.MixedElement([finat.ufl.BrokenElement(Vi.ufl_element()) for Vi in V])
V_d = FunctionSpace(mesh, broken_elements)
V_d = V.broken_space()

# Set up the functions for the original, hybridized
# and schur complement systems
Expand Down
45 changes: 44 additions & 1 deletion tests/firedrake/regression/test_function_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_function_space_variant(mesh, space):


@pytest.mark.parametrize("modifier",
[BrokenElement, HDivElement,
[HDivElement,
HCurlElement])
@pytest.mark.parametrize("element",
[FiniteElement("CG", triangle, 1),
Expand Down Expand Up @@ -313,3 +313,46 @@ def test_reconstruct_sub_component(dg0, rt1, mesh, mesh2, dual):
assert is_primal(V1.parent) == is_primal(V2.parent) != dual
assert V1.parent.ufl_element() == V2.parent.ufl_element()
assert V1.parent.index == V2.parent.index == index


@pytest.mark.parametrize("family", ("CG", "BDM", "DG"))
@pytest.mark.parametrize("shape", (0, 2, (2, 3)), ids=("0", "2", "(2,3)"))
def test_broken_space(mesh, shape, family):
"""Check that FunctionSpace.broken_space returns the a
FunctionSpace with the correct element.
"""
kwargs = {"variant": "spectral"} if family == "DG" else {}

elem = FiniteElement(family, mesh.ufl_cell(), 1, **kwargs)

if not isinstance(shape, int):
make_element = lambda elem: TensorElement(elem, shape=shape)
elif shape > 0:
make_element = lambda elem: VectorElement(elem, dim=shape)
else:
make_element = lambda elem: elem

fs = FunctionSpace(mesh, make_element(elem))
broken = fs.broken_space()
expected = FunctionSpace(mesh, make_element(BrokenElement(elem)))

assert broken == expected


def test_mixed_broken_space(mesh):
"""Check that MixedFunctionSpace.broken_space returns the a
MixedFunctionSpace with the correct element.
"""

mixed_elem = MixedElement([
FiniteElement("CG", mesh.ufl_cell(), 1),
VectorElement("BDM", mesh.ufl_cell(), 2, dim=2),
TensorElement("DG", mesh.ufl_cell(), 1, shape=(2, 3), variant="spectral")
])
broken_elem = MixedElement([BrokenElement(elem) for elem in mixed_elem.sub_elements])

mfs = FunctionSpace(mesh, mixed_elem)
broken = mfs.broken_space()
expected = FunctionSpace(mesh, broken_elem)

assert broken == expected
Loading